rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License
5.61k stars · 930 forks

Enabling Mixed Precision Training for Your Model #252

Open xwqtju opened 5 months ago

xwqtju commented 5 months ago

Hi,

I hope this message finds you well. I have been working with your torchdiffeq library and have found it extremely valuable for my project. I am particularly interested in leveraging mixed precision (half-precision) training to increase training speed and efficiency on my hardware.

However, I have encountered some difficulties in enabling mixed precision training with your model. It seems that the current implementation does not fully support this feature.

Could you please provide guidance on how to modify the model to support half-precision training using torch.cuda.amp? Specifically, any advice on changes that need to be made in the model architecture or training loop would be greatly appreciated.
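For reference, here is the kind of pattern I have been attempting. This is only a sketch of the standard `torch.cuda.amp` recipe wrapped around `odeint`, not an official torchdiffeq recipe; the `ODEFunc` module, dimensions, and training loop are placeholders I made up, and the RK4 fallback is just so the snippet runs where torchdiffeq is not installed. (One concern I am aware of: adaptive solvers do error-tolerance comparisons, which may behave poorly if intermediate states are cast to fp16.)

```python
# Hypothetical sketch: wrapping an ODE model in autocast + GradScaler.
# Names (ODEFunc, dims, loop) are illustrative, not from torchdiffeq.
import torch
import torch.nn as nn

try:
    from torchdiffeq import odeint
except ImportError:
    # Minimal fixed-step RK4 stand-in so the sketch runs without torchdiffeq.
    def odeint(func, y0, t):
        ys, y = [y0], y0
        for i in range(len(t) - 1):
            h = t[i + 1] - t[i]
            k1 = func(t[i], y)
            k2 = func(t[i] + h / 2, y + h / 2 * k1)
            k3 = func(t[i] + h / 2, y + h / 2 * k2)
            k4 = func(t[i] + h, y + h * k3)
            y = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
            ys.append(y)
        return torch.stack(ys)

class ODEFunc(nn.Module):
    """Toy dynamics network f(t, y) -> dy/dt."""
    def __init__(self, dim=4):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 32), nn.Tanh(),
                                 nn.Linear(32, dim))
    def forward(self, t, y):
        return self.net(y)

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # torch.cuda.amp only applies on GPU

func = ODEFunc().to(device)
opt = torch.optim.Adam(func.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

y0 = torch.randn(8, 4, device=device)
t = torch.linspace(0.0, 1.0, 5, device=device)
target = torch.randn(8, 4, device=device)

for step in range(2):
    opt.zero_grad()
    # Autocast the forward pass (solver + dynamics network).
    with torch.autocast(device_type=device, enabled=use_amp):
        y_pred = odeint(func, y0, t)[-1]  # state at the final time point
        loss = nn.functional.mse_loss(y_pred.float(), target)
    # GradScaler prevents fp16 gradient underflow; no-op when disabled.
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

print(loss.item())
```

On CPU this runs in full precision (`enabled=False` makes both contexts no-ops), which is also a convenient way to A/B-test whether AMP is the source of any numerical issue.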

Thank you very much for your time and assistance. I look forward to your response.

Best regards,