rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

gcc version? #14

Closed: rkruegs123 closed this issue 6 months ago

rkruegs123 commented 6 months ago

Hi, thank you for this amazing package. I am making excellent use of it on some systems, but I am having trouble with compilation on others. Which version of gcc do you use for testing?

rdyro commented 6 months ago

Hey, thanks!

I'm using the PyTorch dynamic compilation utilities, which mostly use ninja under the hood. My own test system has gcc-12, but I'm not entirely sure which compiler PyTorch picks up (I also have gcc-11 installed).

For the official PyTorch module that wraps the compilation, take a look at torch.utils.cpp_extension.

For more details on how torch2jax compiles its extension, take a look at the logic in torch2jax/compile.py.
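To find out which compiler a given machine would actually use, a quick check like the one below may help. It is a minimal sketch, assuming that on Linux torch.utils.cpp_extension takes the C++ compiler from the `CXX` environment variable and otherwise falls back to the generic `c++` driver (on many distributions a symlink to the default g++). The helper names `resolve_cxx_compiler` and `compiler_version` are hypothetical and not part of torch2jax or PyTorch.

```python
import os
import shutil
import subprocess
from typing import Mapping, Optional


def resolve_cxx_compiler(env: Mapping[str, str]) -> str:
    """Hypothetical helper. Assumes PyTorch's Linux behavior:
    use $CXX if it is set, otherwise fall back to `c++`."""
    return env.get("CXX", "c++")


def compiler_version(compiler: str) -> Optional[str]:
    """Hypothetical helper. Returns the first line of `<compiler> --version`,
    or None if the compiler is not on PATH or fails to run."""
    path = shutil.which(compiler)
    if path is None:
        return None
    try:
        result = subprocess.run(
            [path, "--version"], capture_output=True, text=True, check=True
        )
    except (OSError, subprocess.CalledProcessError):
        return None
    lines = result.stdout.splitlines()
    return lines[0] if lines else None


if __name__ == "__main__":
    cxx = resolve_cxx_compiler(os.environ)
    print(f"compiler: {cxx}")
    print(f"version:  {compiler_version(cxx)}")
```

Under the same assumption, exporting something like `CXX=g++-11` before the extension is first built should steer the build toward a specific gcc; a previously cached build may need to be removed for the change to take effect.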

Let me know if this helps!

rkruegs123 commented 6 months ago

Thank you for the quick response, and thank you again for your wonderful repository!