ACEsuit / mace

MACE - Fast and accurate machine learning interatomic potentials with higher order equivariant message passing.

initial torch.compile support (inference only) #300

Closed: hatemhelal closed this 5 months ago

hatemhelal commented 5 months ago

This PR adds some test cases that use torch.compile, along with a new module, mace.tools.compile, that contains helper utilities for making MACE compatible with torch.compile.

Some of the changes needed include:

Note that the remaining graph break in the compiled inference model comes from using autograd to evaluate the forces. This might be fixable, but I expect that would be easier to do in a separate PR.
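To illustrate where that graph break comes from: forces are the negative gradient of the energy with respect to atomic positions, so an inference model that returns forces calls torch.autograd.grad inside its forward pass. The following is a minimal sketch, not MACE's model or the helpers in mace.tools.compile, that assumes a hypothetical toy spring potential (`ToySpringPotential`) and a hypothetical `energy_and_forces` helper. It compiles only the energy evaluation and keeps the autograd call outside the compiled region. `backend="eager"` is chosen only so the sketch runs without a C/C++ toolchain.

```python
import torch
from torch import nn


class ToySpringPotential(nn.Module):
    """Hypothetical stand-in for an interatomic potential (NOT MACE):
    harmonic springs of rest length r0 between every pair of atoms."""

    def __init__(self, r0: float = 1.0):
        super().__init__()
        self.r0 = r0

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        n = positions.shape[0]
        # All atom pairs (i, j) with i < j.
        idx = torch.triu_indices(n, n, offset=1, device=positions.device)
        i, j = idx[0], idx[1]
        r = (positions[j] - positions[i]).norm(dim=-1)
        return 0.5 * ((r - self.r0) ** 2).sum()


def energy_and_forces(model: nn.Module, positions: torch.Tensor):
    """Hypothetical helper: forces = -dE/d(positions), computed with autograd
    outside whatever region `model` may have been compiled into."""
    positions = positions.detach().requires_grad_(True)
    energy = model(positions)
    (grad,) = torch.autograd.grad(energy, positions)
    return energy.detach(), -grad


eager_model = ToySpringPotential()
# Only the energy function is compiled; torch.autograd.grad stays in eager code.
compiled_model = torch.compile(eager_model, backend="eager")

pos = torch.tensor([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
e_eager, f_eager = energy_and_forces(eager_model, pos)
e_comp, f_comp = energy_and_forces(compiled_model, pos)
print(e_eager.item(), f_eager.tolist())
```

With two atoms 2 units apart and r0 = 1, the energy is 0.5 and the atoms are pulled toward each other with unit force. Whether splitting the computation this way suits MACE's actual forward pass is an open assumption here, in line with the note above that the fix belongs in a separate PR.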

hatemhelal commented 5 months ago

I experimented with the different torch.compile mode options and measured the following approximate speedups over eager mode on an NVIDIA A10G GPU:

| mode            | speedup over eager (A10G) |
| --------------- | ------------------------- |
| default         | 1.4x                      |
| reduce-overhead | 1.6x                      |
| max-autotune    | 1.6x                      |

To get these values, I ran:

pytest -s -k test_inference_speedup
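The speedups above are ratios of eager runtime to compiled runtime. The test_inference_speedup test itself is not reproduced here. What follows is only a minimal sketch of such a timing harness, with hypothetical helper names `median_runtime` and `speedup`. It assumes two things: warm-up calls are discarded, because the first call to a compiled module triggers compilation (and, in reduce-overhead mode, CUDA graph recording), and CUDA is synchronized before the clock stops, because kernels launch asynchronously.

```python
import statistics
import time
from typing import Callable

import torch


def median_runtime(fn: Callable[[], object], warmup: int = 3, iters: int = 20) -> float:
    """Hypothetical helper: median wall-clock seconds per call of `fn`.

    Warm-up calls are excluded so one-off compilation cost is not timed.
    """
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(warmup):
        fn()
    sync()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        sync()  # wait for queued CUDA kernels before stopping the clock
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def speedup(eager_fn: Callable[[], object], compiled_fn: Callable[[], object], **kwargs) -> float:
    """Hypothetical helper: eager time / compiled time (> 1 means compiled is faster)."""
    return median_runtime(eager_fn, **kwargs) / median_runtime(compiled_fn, **kwargs)
```

For example, with a hypothetical `model` and `batch`, one would build `compiled = torch.compile(model, mode="reduce-overhead")` once, outside the timed lambda, and then call `speedup(lambda: model(batch), lambda: compiled(batch))`. Repeating this for `"default"`, `"reduce-overhead"` and `"max-autotune"` yields one ratio per row of the table.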

Note: I still need to investigate the following warning:

skipping cudagraphs due to complex input striding

This warning appears with the non-default modes, and I suspect it is caused by how the input batches are created in the test case.
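One way to test that suspicion, assuming the warning is triggered by input tensors whose strides are not plain contiguous (an assumption, not a confirmed diagnosis), is to report which batch tensors are non-contiguous and to pass contiguous copies to the compiled model. The helper names `report_noncontiguous` and `contiguous_batch` and the batch keys are hypothetical. Real MACE batches are not necessarily plain dicts.

```python
from typing import Any, Dict

import torch


def report_noncontiguous(batch: Dict[str, Any]) -> Dict[str, tuple]:
    """Hypothetical helper: {key: stride} for each tensor in `batch` that is not contiguous."""
    return {
        k: tuple(v.stride())
        for k, v in batch.items()
        if isinstance(v, torch.Tensor) and not v.is_contiguous()
    }


def contiguous_batch(batch: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: copy of `batch` with every tensor made contiguous;
    non-tensor values are passed through unchanged."""
    return {
        k: v.contiguous() if isinstance(v, torch.Tensor) else v
        for k, v in batch.items()
    }


# A transposed view is a typical source of non-contiguous strides.
positions = torch.randn(3, 5).t()  # shape (5, 3), strides (1, 5)
batch = {"positions": positions, "num_graphs": 1}  # hypothetical keys
print(report_noncontiguous(batch))  # {'positions': (1, 5)}
fixed = contiguous_batch(batch)
print(report_noncontiguous(fixed))  # {}
```

If the warning persists with contiguous inputs, its cause lies elsewhere (for example in tensors created inside the model), and this check at least rules out the batch construction.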