diff --git a/include/neural_net/neural_net.py b/include/neural_net/neural_net.py index cc04a98..fc0ff51 100644 --- a/include/neural_net/neural_net.py +++ b/include/neural_net/neural_net.py @@ -211,6 +211,8 @@ def __init__(self, "mps" if torch.backends.mps.is_available() else "cpu") self.num_workers = num_workers gsim_logger.info(f"Using {self.device_type} device") + if self.device_type == 'cuda': + torch.set_float32_matmul_precision('high') if nn_folder is None: gsim_logger.warning("* " * 50) gsim_logger.warning( @@ -378,7 +380,12 @@ def _get_loss(self, data: tuple[InputType, TargetType], input_batch = self._move_to_device(input_batch) targets_batch = self._move_to_device(targets_batch) - output_batch = self(input_batch) + with torch.autocast('cuda', dtype=torch.bfloat16, enabled=self.device_type == 'cuda'): + output_batch = self(input_batch) + # Cast tensor outputs back to fp32 before the loss — avoids underflow/NaN + # for MAPE/MAE when output is bf16 (non-tensor outputs are left as-is). + if isinstance(output_batch, torch.Tensor): + output_batch = output_batch.float() loss = f_loss(output_batch, targets_batch) if isinstance(targets_batch, torch.Tensor): @@ -641,7 +648,9 @@ def make_output(l_out, output_class): # Run the forward pass input_batch = self._move_to_device(input_batch) - output_batch = self._move_to_cpu(self(input_batch)) + with torch.autocast('cuda', dtype=torch.bfloat16, enabled=self.device_type == 'cuda'): + output_batch = self(input_batch) + output_batch = self._move_to_cpu(output_batch) if unnormalize and self.normalizer is not None: output_batch = self.normalizer.unnormalize_output_batch( output_batch)