materialsvirtuallab / matgl

Graph deep learning library for materials
BSD 3-Clause "New" or "Revised" License

Adding Tensor Placement Calls For Ease of Training with PyTorch Lightning #215

Closed melo-gonzo closed 8 months ago

melo-gonzo commented 8 months ago

Summary

This PR adds the tensor.to() calls needed to place tensors created during training on the correct device. PyTorch Lightning recommends this so that models can be trained in a hardware-agnostic manner and scale to an arbitrary number of devices without any other code changes.

With these changes, it is no longer necessary to call torch.set_default_device("cuda") or to give the dataloaders a default generator with generator=torch.Generator(device='cuda'). The updates also allow num_workers>0 to be used during training, a problem reported in other issues such as #213 and #105 from MatSci ML.

Major changes:

Added tensor.to() calls that move newly created tensors to the correct device during training. The device is taken from the input tensors of the function in which each new tensor is created.
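A minimal sketch of the pattern described above, for illustration only. The helper `add_self_loop_weights` and its arguments are hypothetical and are not taken from the matgl code base; the only point is that a tensor created inside a function is moved to the device of the function's input tensor.

```python
import torch


def add_self_loop_weights(edge_weights: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Append one unit weight per node to an existing edge-weight tensor.

    Hypothetical helper for illustration; not part of matgl.
    """
    # torch.ones() with no device argument allocates on the default device
    # (normally the CPU), regardless of where edge_weights lives.
    loops = torch.ones(num_nodes)
    # Move the new tensor to the input's device (and match its dtype) so the
    # concatenation below works on CPU, a single GPU, or any Lightning-managed device.
    loops = loops.to(device=edge_weights.device, dtype=edge_weights.dtype)
    return torch.cat([edge_weights, loops])


# Runs unchanged on CPU-only machines and on machines with a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = torch.rand(5, device=device)
out = add_self_loop_weights(weights, num_nodes=3)
```

Because the device comes from the input, the caller never has to set a global default device. Passing `device=edge_weights.device` directly to `torch.ones()` would avoid the intermediate copy; the `.to()` form is shown here because it is the one this PR describes.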

Checklist

Tip: Install pre-commit hooks to auto-check types and linting before every commit:

pip install -U pre-commit
pre-commit install

codecov[bot] commented 8 months ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Comparison is base (6759ffb) 98.59% compared to head (81d3035) 98.60%.

Additional details and impacted files

```diff
@@           Coverage Diff           @@
##             main     #215   +/-   ##
=======================================
  Coverage   98.59%   98.60%
=======================================
  Files          28       28
  Lines        1927     1930    +3
=======================================
+ Hits         1900     1903    +3
  Misses         27       27
```


shyuep commented 8 months ago

Thanks! Very helpful!