rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

Minor API change #13

Closed: zwei-beiner closed this issue 6 months ago

zwei-beiner commented 6 months ago

The neural net I'm using torch2jax for uses torch.autograd to compute gradients internally. To get torch2jax working, I had to delete the lines with torch.no_grad() in api.py (2 occurrences) and in gradients.py (1 occurrence).
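For context, a minimal sketch of the failure mode (energy_grad is a hypothetical stand-in for such a net, not code from torch2jax): a function that calls torch.autograd.grad internally raises an error when invoked under torch.no_grad(), because no graph is recorded for the forward pass.

```python
import torch

def energy_grad(x):
    # Hypothetical model that computes a gradient internally with
    # torch.autograd, like the net described above.
    x = x.detach().requires_grad_(True)
    energy = (x ** 2).sum()
    (grad,) = torch.autograd.grad(energy, x)
    return grad

x = torch.randn(3)
print(energy_grad(x))  # works: returns 2 * x

with torch.no_grad():  # what the wrapper was doing internally
    energy_grad(x)     # RuntimeError: element 0 of tensors does not
                       # require grad and does not have a grad_fn
```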

Would it be possible to make the use of with torch.no_grad() optional at the user level? I suppose the only reason to use it here is to speed up output shape inference. Maybe this could be a flag like uses_gradients_internally with a default value of False.
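A sketch of what that opt-out could look like; uses_gradients_internally and _maybe_no_grad are hypothetical names from this proposal, not part of the torch2jax API:

```python
import contextlib
import torch

def _maybe_no_grad(uses_gradients_internally: bool = False):
    # Hypothetical helper: keep the no_grad speedup for output shape
    # inference by default, but let users opt out when the wrapped
    # torch code needs autograd tracking internally.
    if uses_gradients_internally:
        return contextlib.nullcontext()
    return torch.no_grad()

# Sketch of how the wrapper could use it during shape inference:
# with _maybe_no_grad(uses_gradients_internally):
#     example_out = torch_fn(*example_args)
```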

rdyro commented 6 months ago

Great catch! This is a bug: the no_grad context manager shouldn't be there at all. I've just updated it on the multi-gpu branch!
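As a quick end-to-end check that removing no_grad unblocks such models, a hedged sketch; it assumes the torch2jax(torch_fn, *example_args) wrapper shown in the project README and reuses the hypothetical energy_grad from above:

```python
import jax.numpy as jnp
import torch
from torch2jax import torch2jax

def energy_grad(x):
    # Same hypothetical model as above: needs autograd tracking.
    x = x.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad((x ** 2).sum(), x)
    return grad

# With the no_grad calls removed, the wrapper can call energy_grad
# once to infer output shapes, then expose it to JAX.
jax_fn = torch2jax(energy_grad, torch.randn(3))
print(jax_fn(jnp.ones(3)))  # expected: [2. 2. 2.]
```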

zwei-beiner commented 6 months ago

Thanks a lot for the quick response; closing this now.

zwei-beiner commented 6 months ago

Just noticed that the no_grad is still in api.py on lines 69 and 206, so these should be removed as well.

rdyro commented 6 months ago

This is incorporated now!