TorchMD-Net does not really support any dtype other than float32.
For production runs this is probably fine, but I believe it would be useful to be able to pass a dtype argument to TorchMD-Net and run the full model in double precision (float64) for testing and development purposes.
For instance, this would let us check gradients with `torch.autograd.gradcheck`, whose default tolerances are designed for double-precision inputs.
This amounts to adding a bunch of `dtype=dtype` arguments here and there.
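Roughly what I have in mind, as a minimal sketch: `ToyEnergyModel` and `forces` are hypothetical names for illustration only, not TorchMD-Net's actual classes. The only real APIs used are PyTorch's `dtype` keyword on `nn.Linear`, `torch.autograd.grad`, and `torch.autograd.gradcheck`.

```python
import torch
import torch.nn as nn


class ToyEnergyModel(nn.Module):
    """Hypothetical stand-in for a TorchMD-Net model (not its real API)."""

    def __init__(self, hidden: int = 16, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.dtype = dtype
        # Layers accept dtype directly, so parameters are created in the requested precision.
        self.embed = nn.Linear(3, hidden, dtype=dtype)
        self.out = nn.Linear(hidden, 1, dtype=dtype)

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        # Constants built inside forward would default to float32 without an explicit dtype.
        scale = torch.tensor(0.5, dtype=self.dtype)
        h = torch.tanh(scale * self.embed(pos))
        # Sum per-atom contributions into a single scalar energy.
        return self.out(h).sum()


def forces(model: nn.Module, pos: torch.Tensor) -> torch.Tensor:
    energy = model(pos)
    # create_graph=True keeps the graph so gradcheck can differentiate the forces.
    (grad,) = torch.autograd.grad(energy, pos, create_graph=True)
    return -grad


torch.manual_seed(0)
model = ToyEnergyModel(dtype=torch.float64)
pos = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)

# gradcheck compares analytical and finite-difference gradients; it returns True or raises.
ok_energy = torch.autograd.gradcheck(model, (pos,))
ok_forces = torch.autograd.gradcheck(lambda p: forces(model, p), (pos,))
print(ok_energy, ok_forces)
```

Calling `model.double()` would convert existing parameters and buffers, but tensors created inside `forward` with factory functions such as `torch.tensor` or `torch.zeros` still use the default dtype (float32 unless changed) when no dtype is given, which is why the argument has to be threaded through explicitly.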
What do you think?