Closed melo-gonzo closed 8 months ago
All modified and coverable lines are covered by tests :white_check_mark:
Comparison is base (6759ffb) 98.59% compared to head (81d3035) 98.60%.
Thanks! Very helpful!
Summary
This PR adds necessary tensor.to() calls to properly set up new tensors created during training. This is recommended by PyTorch Lightning so that training models can be done in a hardware-agnostic manner and scale to an arbitrary number of devices without changing any other code.
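The PR's actual call sites are not shown here, but the pattern it describes is straightforward. As a minimal sketch, `add_noise` below is a hypothetical helper standing in for any function that creates a new tensor during training: the fresh tensor is moved with `.to(x.device)` so the function works unchanged whether its inputs live on CPU or GPU.

```python
import torch

def add_noise(x: torch.Tensor, scale: float = 0.1) -> torch.Tensor:
    # torch.randn allocates on the default device (CPU unless overridden);
    # .to(x.device) moves the new tensor to wherever the input lives, so no
    # global torch.set_default_device("cuda") call is needed.
    noise = torch.randn(x.shape).to(x.device)
    return x + scale * noise

x = torch.ones(2, 3)          # same code runs for CPU or CUDA inputs
y = add_noise(x)
assert y.device == x.device
```

Deriving the device from the input tensors, rather than from a global default, is what makes the training loop hardware-agnostic in the sense Lightning recommends.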
Adding these changes removes the need for calling `torch.set_default_device("cuda")` and for setting a default generator for dataloaders with `generator=torch.Generator(device='cuda')`. Additionally, these updates ensure `num_workers>0` may be used during training, which has come up in other issues, such as #213 and #105 from MatSci ML.

Major changes:

- Added `tensor.to()` calls to properly move tensors to the correct device during training. The device is selected based on the input tensors to the respective functions where new tensors are created.

Checklist
- Ran `ruff`.
- Ran `mypy`.
- Added `duecredit` `@due.dcite` decorators to reference relevant papers by DOI (example).

Tip: Install `pre-commit` hooks to auto-check types and linting before every commit:
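The command that followed this tip was lost in extraction; assuming the repository ships a `.pre-commit-config.yaml`, the standard setup is:

```shell
# Install the pre-commit tool, then register its hook in your clone of the
# repository so checks run automatically on every commit.
pip install pre-commit
pre-commit install
```

This is the generic `pre-commit` workflow, not necessarily the exact command the original comment showed.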