maxxxzdn / jax-geometric

Implementation of various equivariant models in JAX
MIT License
10 stars, 0 forks

Profiling code for low-level CUDA analysis? #1

Open mitkotak opened 1 month ago

mitkotak commented 1 month ago

Coming here after reading this Twitter thread, and the results look super cool!

Is it possible to share the profiling scripts for both of the plots referenced there? I was interested in plugging them into NVIDIA's Nsight Systems and looking at the low-level CUDA kernels (similar to what was done here).

Thanks for sharing this work!
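For context, here is a minimal sketch of the kind of script that could be run under Nsight Systems. It is not the author's profiling code: `forward` and the array shapes are hypothetical stand-ins for a model forward pass. The idea is to compile ahead of time so that the profiled loop contains only kernel executions, without tracing or compilation.

```python
import jax
import jax.numpy as jnp


def forward(x, w):
    # Hypothetical stand-in for a model forward pass (not CEGNN).
    return jnp.tanh(x @ w).sum()


x = jnp.ones((256, 64))
w = jnp.ones((64, 64))

# Lower and compile ahead of time so the loop below triggers no compilation.
compiled = jax.jit(forward).lower(x, w).compile()

# Steady-state region: only compiled executions, each one synchronized
# because JAX dispatches work asynchronously.
for _ in range(10):
    out = jax.block_until_ready(compiled(x, w))

print(float(out))
```

Saved to a file (say `profile_forward.py`, a made-up name), this could be launched with `nsys profile -o report python profile_forward.py` to capture the kernels on a GPU machine.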

Edit: Found this helpful thread about benchmarking torch.compile which might be relevant.

maxxxzdn commented 1 month ago

Hi Mit,

Sure, I will push the code soon, likely on Friday/Saturday. Thanks for sharing!

mitkotak commented 1 month ago

Apologies for reopening this, but I was wondering whether you have timing/benchmarking scripts that I can quickly run on my RTX A5500 to see what the numbers look like and hopefully do an apples-to-apples comparison with your plots. Thanks again for doing this!
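A timing script of this kind might look roughly like the sketch below. It is an illustration, not the script from the repository: `benchmark`, `toy_layer`, and the shapes are hypothetical. The two points that matter for a fair comparison are warming up first, so JIT compilation is excluded, and blocking on the result, so asynchronous dispatch does not hide the actual run time.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, warmup=3, iters=20):
    """Return the mean wall-clock time per call of fn(*args), in milliseconds."""
    # Warm-up calls trigger JIT compilation so it is not counted in the timing.
    for _ in range(warmup):
        jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(iters):
        # Block on each result so the timer sees the finished computation.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / iters * 1e3


@jax.jit
def toy_layer(x, w):
    # Hypothetical stand-in for a model forward pass (not CEGNN).
    return jnp.tanh(x @ w).sum()


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (256, 64))
w = jax.random.normal(key_w, (64, 64))

ms = benchmark(toy_layer, x, w)
print(f"{ms:.3f} ms per call on {jax.devices()[0].platform}")
```

Reporting the device platform next to the number makes it easier to tell whether two runs are actually comparable.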

mitkotak commented 1 month ago

Also, I was trying to get torch.compile to work with fullgraph=True and, as you said on Twitter, there seem to be some errors. Is it OK if I file issues for them so that other folks can help out with debugging?

It would be cool to compare jax.jit vs torch.compile for these types of implementations. (Don't worry, I am not trying to start another Twitter war :D )

mitkotak commented 1 month ago

FYI, for SEGNN and EGNN there are also these implementations, which would be interesting to compare against.

maxxxzdn commented 1 month ago

I just added the script for CEGNN in JAX.

Concerning opening issues, please go ahead; I would love for people to help. I was recently playing with CEGNNs and managed to identify the operations that make torch struggle with compilation:

I will update the codebase if I manage to resolve them :)