pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Unable to compute derivatives due to calling .item() #1066

Closed: elientumba2019 closed this issue 1 year ago

elientumba2019 commented 1 year ago

Hello, I am getting the error below whenever I try to compute the Jacobian of my network.

RuntimeError: vmap: It looks like you're either (1) calling .item() on a Tensor or (2) attempting to use a Tensor in some data-dependent control flow or (3) encountering this error in PyTorch internals. For (1): we don't support vmap over calling .item() on a Tensor, please try to rewrite what you're doing with other operations. For (2): If you're doing some control flow instead, we don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 . For (3): please file an issue.

The error can be traced back to the line below:

weights = interpolation_weights.prod(-1)

Is there a way around this?

Thank you.
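The error message's advice for cases (1) and (2), rewriting `.item()` calls or data-dependent branching with other tensor operations, can be sketched as follows. `scale_branchy` and `scale_where` are hypothetical examples, not code from this issue, and `torch.where` is just one common substitute chosen for illustration. The import fallback assumes `vmap` is available as `torch.func.vmap` on PyTorch 2.0+ and as `functorch.vmap` on older releases.

```python
import torch

try:
    from torch.func import vmap  # PyTorch 2.0+
except ImportError:
    from functorch import vmap  # older standalone functorch

# Hypothetical example of data-dependent Python control flow: the `if`
# converts a tensor to a Python bool (an implicit .item()), which vmap
# cannot do per example.
def scale_branchy(x):
    if x.sum() > 0:
        return x * 2
    return x * -1

# Same logic written with tensor ops: torch.where computes both branches
# and selects per example, so it composes with vmap.
def scale_where(x):
    return torch.where(x.sum() > 0, x * 2, x * -1)

batch = torch.tensor([[1.0, 2.0], [-3.0, 1.0]])

try:
    vmap(scale_branchy)(batch)
except RuntimeError as err:
    print("branchy version fails under vmap:", type(err).__name__)

out = vmap(scale_where)(batch)
print(out)  # row sums are 3 and -2, so rows become [2, 4] and [3, -1]
```

Because `torch.where` evaluates both branches, this pattern trades a little extra compute for compatibility with `vmap` and the transforms built on it, such as `jacrev`.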

samdow commented 1 year ago

Hi @elientumba2019, thanks for the issue! Are you able to share more of the code with us? When testing locally, I'm not able to reproduce this using toy data with a prod call, so I want to check what's happening. Thanks!
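A toy-data check along the lines described above might look like this sketch. `weights_fn` and the tensor values are hypothetical stand-ins, since the reporter's network was not shared; the import fallback assumes `jacrev` lives in `torch.func` on PyTorch 2.0+ and in the standalone `functorch` package on older releases.

```python
import torch

try:
    from torch.func import jacrev  # PyTorch 2.0+
except ImportError:
    from functorch import jacrev  # older standalone functorch

# Hypothetical stand-in for the failing step of the network:
# multiply interpolation weights along the last dimension.
def weights_fn(interpolation_weights):
    return interpolation_weights.prod(-1)

# Toy input: 2 samples, 3 weights each.
x = torch.tensor([[0.5, 2.0, 3.0],
                  [1.0, 4.0, 0.25]])

# jacrev builds the Jacobian with vmap over vector-Jacobian products.
jac = jacrev(weights_fn)(x)
print(jac.shape)  # torch.Size([2, 2, 3])
```

Each nonzero entry is the product of the other entries in the same row (for example, `jac[0, 0]` is `[6.0, 1.5, 1.0]`), which gives an easy sanity check against `torch.autograd.functional.jacobian`.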

zou3519 commented 1 year ago

@elientumba2019 if you could also let us know what version of PyTorch you're using, that would be great. This problem was likely fixed in PyTorch 1.13, so upgrading to the latest PyTorch may make it go away.

elientumba2019 commented 1 year ago

@zou3519 @samdow Thank you for the suggestion. I have upgraded my PyTorch version and the problem seems to have been fixed.