pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Forward auto differentiation not implemented for torch softplus nonlinearity #861

Open lrast opened 2 years ago

lrast commented 2 years ago

Found when using functorch.hessian on a model with a torch.nn.Softplus() nonlinearity. Error message:

NotImplementedError: Trying to use forward AD with softplus_backward that does not support it.

Writing the nonlinearity by hand fixed the issue. Reporting as requested.
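For anyone hitting the same error, here is a minimal sketch of what "writing the nonlinearity by hand" could look like. The original report does not include its code, so `ManualSoftplus` and `scalar_model` are hypothetical names, and the naive formula `log(1 + exp(x))` is an assumption about the workaround. The import falls back from `torch.func` (PyTorch 2.0+) to `functorch` for older installs.

```python
import torch

try:
    # PyTorch 2.0+ ships the functorch transforms under torch.func.
    from torch.func import hessian
except ImportError:
    # Older installs, e.g. PyTorch 1.11 + functorch 0.1.*.
    from functorch import hessian


class ManualSoftplus(torch.nn.Module):
    """Hand-written softplus, log(1 + exp(x)), built from elementwise ops."""

    def forward(self, x):
        return torch.log(1 + torch.exp(x))


def scalar_model(x):
    # Hypothetical stand-in for the model in the report; returns a scalar.
    return ManualSoftplus()(x).sum()


x = torch.tensor([-1.0, 0.0, 2.0])
H = hessian(scalar_model)(x)

# Analytic check: the second derivative of softplus is
# sigmoid(x) * (1 - sigmoid(x)), and the Hessian of the sum is diagonal.
expected = torch.diag(torch.sigmoid(x) * (1 - torch.sigmoid(x)))
print(torch.allclose(H, expected, atol=1e-5))
```

Note that this naive form overflows for large inputs, whereas torch.nn.Softplus switches to a linear function above its `threshold` argument (default 20), so the two are not drop-in equivalents for all inputs.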

zou3519 commented 2 years ago

Thanks for reporting, @lrast. I assume that you're using PyTorch 1.11 and functorch 0.1.*.

This problem should be fixed in the pytorch nightly + functorch main branch. If you're interested in trying that out, please see the instructions over at https://github.com/pytorch/functorch#installing-functorch-from-source .