zwei-beiner closed this 6 months ago.
Great catch! This is a bug; the `no_grad` option shouldn't be there at all. I've just updated it on the multi-gpu branch!
Thanks a lot for the quick response, closing this now.
Just noticed that `no_grad` is still in `api.py` on lines 69 and 206, so it should be removed there as well.
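Here is a minimal sketch of why a leftover `torch.no_grad()` matters for networks that compute gradients with `torch.autograd` inside their own forward pass. `InputGradientModel` is a hypothetical stand-in, not part of `torch2jax`. It runs normally on its own, but under `torch.no_grad()` PyTorch records no graph, so the internal `torch.autograd.grad` call raises a `RuntimeError`.

```python
import torch


class InputGradientModel(torch.nn.Module):
    """Hypothetical model whose forward pass calls torch.autograd internally:
    it returns d(energy)/d(input) for a small learned energy function."""

    def __init__(self):
        super().__init__()
        self.energy = torch.nn.Linear(3, 1)

    def forward(self, x):
        # Track gradients with respect to the input so autograd can differentiate through it
        x = x.detach().requires_grad_(True)
        energy = self.energy(x).pow(2).sum()
        (grad,) = torch.autograd.grad(energy, x)
        return grad


model = InputGradientModel()
x = torch.randn(4, 3)

out = model(x)  # works: the gradient has the same shape as x
print(out.shape)  # torch.Size([4, 3])

failed_under_no_grad = False
try:
    with torch.no_grad():
        # No graph is recorded here, so autograd.grad has nothing to differentiate
        model(x)
except RuntimeError as err:
    failed_under_no_grad = True
    print("fails under torch.no_grad():", err)
```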
This is incorporated now!
The neural net I'm using `torch2jax` for uses `torch.autograd` to compute gradients internally. To get `torch2jax` working, I had to delete the line `with torch.no_grad()` in `api.py` (2 occurrences) and in `gradients.py` (1 occurrence).

Would it be possible to make the use of `with torch.no_grad()` optional at the user level? I suppose the only reason to use it here is to speed up output shape inference. Maybe this could be something like a `uses_gradients_internally` option with default value `False`.
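A sketch of how such an option could work. `infer_output_shape` and `input_gradient` are hypothetical names, not functions from `torch2jax`, and the actual shape-inference code in `api.py` may look quite different. The helper keeps `torch.no_grad()` as the default speed-up and substitutes `contextlib.nullcontext()` when `uses_gradients_internally=True`.

```python
import contextlib

import torch


def infer_output_shape(fn, example_input, uses_gradients_internally=False):
    """Hypothetical helper: run fn once on example_input and return the output shape.

    torch.no_grad() is only a speed-up for this dry run, so it is skipped when
    uses_gradients_internally=True (i.e. fn calls torch.autograd itself).
    """
    ctx = contextlib.nullcontext() if uses_gradients_internally else torch.no_grad()
    with ctx:
        out = fn(example_input)
    return tuple(out.shape)


def input_gradient(x):
    """Toy function that needs autograd: gradient of sum(x**3) with respect to x."""
    x = x.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(x.pow(3).sum(), x)
    return grad


x = torch.randn(4, 3)
print(infer_output_shape(torch.tanh, x))  # (4, 3), default path uses no_grad
print(infer_output_shape(input_gradient, x, uses_gradients_internally=True))  # (4, 3)
```

If the caller might already be inside a `torch.no_grad()` block, using `torch.enable_grad()` instead of `contextlib.nullcontext()` would force gradient tracking back on for the dry run.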