Closed — smg3d closed this 4 hours ago
Yes, the time difference is compilation (see, for example, https://jax.readthedocs.io/en/latest/profiling.html). You can use a persistent compilation cache (https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html), though JAX will still need to trace the model, so some overhead remains. 22 s sounds reasonable.
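A minimal sketch of measuring that one-time overhead with a toy jitted function — the function and array shapes here are placeholders, not the AlphaFold 3 model:

```python
import time
import jax
import jax.numpy as jnp

# To persist compiled programs across runs (assumed current JAX API),
# you can point JAX at a cache directory, e.g.:
#   jax.config.update("jax_compilation_cache_dir", "/path/to/cache")

@jax.jit
def forward(x):
    # Stand-in for the model's forward pass.
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((256, 256))

t0 = time.perf_counter()
forward(x).block_until_ready()   # first call: trace + compile + run
first = time.perf_counter() - t0

t0 = time.perf_counter()
forward(x).block_until_ready()   # later calls: reuse the compiled program
later = time.perf_counter() - t0

print(f"first call: {first:.3f}s, later call: {later:.3f}s")
```

The gap between the first and later calls is the trace-plus-compile cost; with a persistent cache, subsequent processes skip the compile step but still pay for tracing.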
XLA can do autotuning, which involves running kernels on GPU (note we disable one such pass in https://github.com/google-deepmind/alphafold3/blob/main/docker/Dockerfile#L59; I haven't checked if there are other autotuning passes triggered though). Compilation is otherwise done on CPU.
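As a hedged illustration of disabling GPU autotuning via XLA flags (the exact flag set in the AlphaFold 3 Dockerfile may differ; `--xla_gpu_autotune_level=0` is a generic XLA option that skips running candidate kernels on the GPU during compilation):

```shell
# Level 0 disables GPU autotuning; higher levels try more candidates.
export XLA_FLAGS="--xla_gpu_autotune_level=0"
```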
In the Performance documentation, under Compilation Buckets, it is mentioned that we should preferably use a single compilation of the model.
Question 1: Does the time to compile the model account for most of the difference in inference time between the first seed and the other seeds? In the example below, does compilation take about 22 seconds?
Question 2: Is the compilation of the model done on the GPU or the CPU?