CederGroupHub / chgnet

Pretrained universal neural network potential for charge-informed atomistic modeling https://chgnet.lbl.gov
https://doi.org/10.1038/s42256-023-00716-3

Revert `torch.det` for MPS support #132

Closed. tsihyoung closed this pull request 4 months ago.

tsihyoung commented 4 months ago

Summary

Use torch.linalg.cross to suppress the deprecation warning raised by torch >= 2.2.

Also, since torch.det uses LU decomposition to compute the determinant, I prefer the more direct approach (the mixed, or scalar triple, product) for computing volumes, even if torch.det gains MPS support in the future.
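A minimal sketch of the mixed-product volume, assuming the lattice is a 3x3 tensor whose rows are the lattice vectors a, b, c (consistent with the lattice[1], lattice[2] indexing used later in this thread). The function name volume_triple_product is hypothetical and is not CHGNet's actual code. It computes V = a · (b × c) with torch.linalg.cross and cross-checks the result against torch.det on CPU.

```python
import torch


def volume_triple_product(lattice: torch.Tensor) -> torch.Tensor:
    """Signed cell volume a . (b x c) for a 3x3 lattice whose rows are a, b, c.

    Hypothetical helper for illustration only.
    """
    # torch.linalg.cross uses dim=-1 by default, so no deprecation warning is raised
    cross_bc = torch.linalg.cross(lattice[1], lattice[2])
    # Dot the first lattice vector with b x c to get the mixed product
    return torch.dot(lattice[0], cross_bc)


# Example triclinic-style cell (rows are lattice vectors)
lattice = torch.tensor([[4.0, 0.0, 0.0],
                        [1.0, 5.0, 0.0],
                        [0.5, 0.5, 6.0]])

vol = volume_triple_product(lattice)
print(vol.item())  # 120.0

# Reference value via the LU-based determinant (run on CPU here)
print(torch.det(lattice).item())
```

The mixed product is signed, like the determinant: it is positive for a right-handed cell, so wrap it in abs() if an unsigned volume is needed. It needs only one cross product and one dot product, and so does not depend on the LU path that torch.det takes.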

janosh commented 4 months ago

Thanks @tsihyoung! I think this was superseded by #131? Let me know if I missed something.

tsihyoung commented 4 months ago


I don't think they are the same. I replaced torch.cross with torch.linalg.cross, since calling torch.cross without the dim argument is deprecated.

If you prefer to keep torch.cross, we need to pass dim explicitly, i.e. torch.cross(lattice[1], lattice[2], dim=-1), to suppress the UserWarning.
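The two warning-free options discussed here can be compared side by side. This is a small sketch assuming a 3x3 lattice tensor with lattice vectors as rows (the example values are made up). It records warnings while calling torch.linalg.cross and torch.cross with an explicit dim=-1, then checks that both give the same vector.

```python
import warnings

import torch

# Example lattice (rows are lattice vectors); values are illustrative only
lattice = torch.tensor([[4.0, 0.0, 0.0],
                        [1.0, 5.0, 0.0],
                        [0.5, 0.5, 6.0]])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Option 1: the torch.linalg API, where dim defaults to -1
    via_linalg = torch.linalg.cross(lattice[1], lattice[2])
    # Option 2: keep torch.cross but pass dim explicitly
    via_cross = torch.cross(lattice[1], lattice[2], dim=-1)

print(via_linalg.tolist())  # [30.0, -6.0, -2.0]
print(torch.equal(via_linalg, via_cross))  # True

# Collect any warnings mentioning cross (neither call above omits dim)
cross_warnings = [w for w in caught if "cross" in str(w.message)]
print(len(cross_warnings))
```

Both calls compute the same product; the difference is only in the API surface. torch.linalg.cross is the long-term replacement, while torch.cross(..., dim=-1) is the minimal change to existing code.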

janosh commented 4 months ago

@tsihyoung thanks for pointing that out. I'll change it to torch.linalg.cross in #133.