rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

Use of new JAX Performance Flags? #10

Closed: adam-hartshorne closed this issue 1 year ago

adam-hartshorne commented 1 year ago

Could setting any of these flags cause issues for functions wrapped with torch2jax? (As you know, I use KeOps a lot.)

https://github.com/google/jax/blob/664e834784a35117c244030c28862177dcfc76f0/docs/gpu_performance_tips.md
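For reference, here is a minimal sketch of how flags like these are usually applied. It is not part of torch2jax. XLA reads them from the `XLA_FLAGS` environment variable, which must be set before JAX initializes its backend; setting it before `import jax` is the simplest way to guarantee that. The flag names below are assumptions based on the linked performance-tips page as it looked at the time. Check them against your installed jaxlib, since flags get renamed or removed and newer XLA builds may refuse to start with flags they don't recognize. `merge_xla_flags` is a hypothetical helper.

```python
import os

# Assumed flag names, taken from the linked gpu_performance_tips.md as of mid-2023.
# Verify against your jax/jaxlib version: flags are renamed or dropped over time.
PERF_FLAGS = [
    "--xla_gpu_enable_triton_softmax_fusion=true",
    "--xla_gpu_triton_gemm_any=True",
    "--xla_gpu_enable_async_collectives=true",
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_highest_priority_async_stream=true",
]


def merge_xla_flags(existing, flags):
    """Hypothetical helper: append `flags` to an XLA_FLAGS string, skipping exact duplicates."""
    current = existing.split()
    return " ".join(current + [f for f in flags if f not in current])


# Must happen before JAX initializes its backend, i.e. before `import jax` here.
os.environ["XLA_FLAGS"] = merge_xla_flags(os.environ.get("XLA_FLAGS", ""), PERF_FLAGS)

# import jax  # import jax (and any torch2jax-wrapped functions) only after this point
```

Appending rather than overwriting keeps any flags that were already set by the user or by a cluster launcher.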

rdyro commented 1 year ago

If I understand correctly, these are the specific flags:

The Triton option should not cause any issues.

The async options (the last three flags) are trickier to judge. However, torch2jax uses a conservative synchronization approach, so I would not expect them to cause any issues either.

My best guess is that these flags will not cause any issues.

I'll rerun the test suite with these flags later this week to give a more definitive answer. I'll leave the issue open until then.
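One way to repeat such a check locally, sketched under assumptions: XLA parses `XLA_FLAGS` once, when its backend starts, so each flag configuration is easiest to test in a fresh process. `run_with_xla_flags` is a hypothetical helper, not part of torch2jax or JAX.

```python
import os
import subprocess
import sys


def run_with_xla_flags(cmd, flags, **kwargs):
    """Hypothetical helper: run `cmd` in a child process with `flags` added to XLA_FLAGS."""
    env = dict(os.environ)
    # Keep any flags the parent process already has; add the new ones after them.
    env["XLA_FLAGS"] = " ".join(env.get("XLA_FLAGS", "").split() + list(flags))
    # A fresh interpreter initializes JAX/XLA from scratch, so the flags take effect there.
    return subprocess.run(cmd, env=env, **kwargs)
```

For example, `run_with_xla_flags([sys.executable, "-m", "pytest", "tests"], ["--xla_gpu_enable_latency_hiding_scheduler=true"])` would run a pytest suite with that flag enabled; pytest and the `tests` directory name are assumptions about the repository layout.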

rdyro commented 1 year ago

The tests pass with those flags enabled, so hopefully you will not run into any problems.

Please reopen this issue if anything goes wrong.