Enable CUDA mixed precision in NeuralNet#15

Open
avidaldo wants to merge 1 commit into fachu000:flexible_training from avidaldo:improvements__cuda
Conversation

@avidaldo
Summary

  • Enable TF32 matrix multiplication on CUDA (torch.set_float32_matmul_precision('high')).
  • Run forward passes with CUDA autocast in bfloat16 for training and prediction.
  • Keep loss computation in fp32 to avoid bf16 instability for MAE/MAPE style losses.
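
A minimal sketch of the three points above, assuming a model with tensor outputs and an MAE-style loss (function and variable names are hypothetical, not the PR's actual code):

```python
import torch

# Allow TF32 on Tensor Cores for fp32 matmuls (affects CUDA; a no-op elsewhere).
torch.set_float32_matmul_precision('high')

def forward_step(model, x, y, device_type):
    # Run the forward pass under autocast in bfloat16, but only on CUDA.
    with torch.autocast(device_type=device_type,
                        dtype=torch.bfloat16,
                        enabled=device_type == 'cuda'):
        out = model(x)
    # Compute the loss in fp32: MAE/MAPE-style losses can be unstable in bf16.
    loss = torch.nn.functional.l1_loss(out.float(), y.float())
    return loss
```

Note that bf16 autocast needs no gradient scaler, unlike fp16, since bf16 shares fp32's exponent range.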

Why

  • Improve throughput and Tensor Core utilization on CUDA while preserving training stability.
  • Keep non-CUDA behavior unchanged (enabled=self.device_type == 'cuda').
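
The `enabled` flag is what keeps non-CUDA runs bit-identical: with autocast disabled, operations keep their original fp32 dtype. A small illustration (the `device_type` value is an example):

```python
import torch

device_type = 'cpu'  # example: a non-CUDA device
x = torch.randn(2, 2)
# enabled evaluates to False off-CUDA, so the context is a no-op.
with torch.autocast(device_type=device_type,
                    dtype=torch.bfloat16,
                    enabled=device_type == 'cuda'):
    y = x @ x
# The matmul ran in fp32, unchanged from the pre-PR behavior.
assert y.dtype == torch.float32
```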

Scope

  • include/neural_net/neural_net.py only

Validation

  • Local tests could not be run in this environment because pytest is not installed (`No module named pytest`).

Copilot AI review requested due to automatic review settings March 14, 2026 16:49
Copilot AI left a comment

Pull request overview

This PR introduces CUDA-focused mixed-precision behavior in NeuralNet to improve throughput by enabling TF32 matmul precision and running forward passes under CUDA autocast (bf16), while attempting to keep loss computation in fp32 for stability.

Changes:

  • Enable TF32 matmul precision on CUDA via torch.set_float32_matmul_precision('high').
  • Wrap training/eval forward passes in torch.autocast(..., dtype=torch.bfloat16) when running on CUDA.
  • Cast model outputs to fp32 before loss computation (currently only for tensor outputs).
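
The review's caveat "currently only for tensor outputs" could be sketched as a small helper like the following (name hypothetical, not from the PR):

```python
import torch

def to_fp32(output):
    # Cast tensor outputs back to fp32 before the loss; other output
    # types (tuples, dicts of tensors) currently pass through unchanged.
    if isinstance(output, torch.Tensor):
        return output.float()
    return output
```

A structured output (e.g. a tuple of bf16 tensors) would therefore still reach the loss in bf16 under this scheme.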
