I experimented with the different `torch.compile` mode options and measured approximate speedups over native eager mode as follows:
| mode            | speedup (A10G) |
|-----------------|----------------|
| default         | 1.4            |
| reduce-overhead | 1.6            |
| max-autotune    | 1.6            |
To get these values I ran:

```sh
pytest -s -k test_inference_speedup
```
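For context, here is a minimal sketch of the kind of eager-vs-compiled timing loop behind such a comparison. It is not the actual test; `build_model` and `example_batch` are hypothetical placeholders for the real fixtures:

```python
import time

import torch


def avg_forward_time(model, batch, n_iter: int = 100) -> float:
    """Average forward-pass time in seconds; assumes model and batch are on the GPU."""
    for _ in range(10):  # warm-up, so compilation/autotuning cost is not timed
        model(batch)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_iter):
        model(batch)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_iter


# `build_model()` and `example_batch` are hypothetical placeholders for the test fixtures.
eager_model = build_model().cuda()
t_eager = avg_forward_time(eager_model, example_batch)

for mode in ("default", "reduce-overhead", "max-autotune"):
    compiled_model = torch.compile(build_model().cuda(), mode=mode)
    t_compiled = avg_forward_time(compiled_model, example_batch)
    print(f"{mode}: {t_eager / t_compiled:.2f}x speedup over eager")
```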
Note: I still need to investigate the following warning:

```
skipping cudagraphs due to complex input striding
```

which is seen for the non-default modes and which I suspect is due to how the input batches are created in the test case.
This PR adds some test cases that use `torch.compile` and a new module `mace.tools.compile` that contains some helper utilities for MACE compatibility with `torch.compile`. Some of the changes needed include:

- Removing the `torch.jit.script` annotations from the scatter-reduce implementations. This is necessary as the compiled script functions are not compatible with the inductor backend (see the scatter-sum sketch after this list).
- Calling `e3nn.set_optimization_defaults(jit_script_fx=False)` ahead of creating the model instance. This can be managed with the `disable_e3nn_codegen` context manager (also sketched after this list).
- A `simplify_if_compile` decorator for registering modules for simplification, and a `prepare` function which manages creating the model without e3nn codegen and applies the symbolic tracing simplification to registered modules.
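To illustrate the first item, here is a minimal sketch of a scatter-sum written as a plain Python function, with no `torch.jit.script` annotation, so that `torch.compile` can trace through it and hand it to the inductor backend. This is a simplified stand-in, not the actual MACE scatter-reduce implementation:

```python
import torch


def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim_size: int, dim: int = 0) -> torch.Tensor:
    """Sum entries of `src` into `dim_size` buckets along `dim`, as selected by `index`.

    Written as a plain function (no @torch.jit.script decorator) so the call is
    inlined into the traced graph instead of hitting a pre-compiled script function.
    """
    shape = list(src.shape)
    shape[dim] = dim_size
    out = torch.zeros(shape, dtype=src.dtype, device=src.device)
    # Broadcast the 1-D index so it lines up element-wise with src.
    view = [1] * src.dim()
    view[dim] = -1
    return out.scatter_add_(dim, index.view(view).expand_as(src), src)
```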
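Similarly, a rough sketch of how a `disable_e3nn_codegen`-style context manager can wrap the `e3nn.set_optimization_defaults(jit_script_fx=False)` call; the actual helper in `mace.tools.compile` may differ in detail (in particular, restoring the previous value via `get_optimization_defaults` is an assumption here):

```python
from contextlib import contextmanager

import torch
from e3nn import get_optimization_defaults, set_optimization_defaults


@contextmanager
def disable_e3nn_codegen():
    """Temporarily disable e3nn's TorchScript code generation so that modules created
    inside the block remain plain Python and stay traceable by torch.compile."""
    initial = get_optimization_defaults()["jit_script_fx"]
    set_optimization_defaults(jit_script_fx=False)
    try:
        yield
    finally:
        set_optimization_defaults(jit_script_fx=initial)


# Illustrative usage -- `build_mace_model()` is a hypothetical placeholder.
with disable_e3nn_codegen():
    model = build_mace_model()

compiled_model = torch.compile(model, mode="default")
```

The `prepare` function described above wraps this kind of bookkeeping together with the symbolic-tracing simplification of registered modules, so callers don't have to manage it by hand.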
Note that the remaining graph break in the compiled inference model is due to using autograd to evaluate the forces. This might be possible to fix, but I expect it would be easier to do in another PR.