Optimizing an RNN fails with:
`NotImplementedError: Forward-mode differentiation rule for 'while' not implemented`
This can be verified by running the OCR/RNN example after removing the `break` statement.
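The same class of failure can be reproduced in plain JAX. The minimal sketch below uses two hypothetical functions (`unroll_with_while`, `unroll_with_scan`, not from this codebase) that compute the same toy recurrence: reverse-mode differentiation works through `lax.scan` but not through `lax.while_loop`. The exact exception type and message depend on the JAX version, so recent releases may not print the `NotImplementedError` quoted above.

```python
import jax
import jax.numpy as jnp
from jax import lax


def unroll_with_while(w, xs):
    """Hypothetical toy recurrence h_{t+1} = tanh(w * h_t + x_t) via lax.while_loop."""
    n = xs.shape[0]

    def cond(state):
        i, _ = state
        return i < n

    def body(state):
        i, h = state
        return i + 1, jnp.tanh(w * h + xs[i])

    _, h = lax.while_loop(cond, body, (0, jnp.zeros_like(w)))
    return h


def unroll_with_scan(w, xs):
    """The same recurrence via lax.scan, which supports reverse-mode autodiff."""
    def step(h, x):
        return jnp.tanh(w * h + x), None  # new carry, no per-step output

    h, _ = lax.scan(step, jnp.zeros_like(w), xs)
    return h


if __name__ == "__main__":
    xs = jnp.linspace(-1.0, 1.0, 5)
    w = jnp.float32(0.5)
    print("scan grad:", jax.grad(unroll_with_scan)(w, xs))
    try:
        jax.grad(unroll_with_while)(w, xs)
    except Exception as e:  # error type/message varies across JAX versions
        print("while_loop grad failed:", type(e).__name__, e)
```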
`scan` from JAX already supports training, but we have to use a custom version of `scan` to allow parameterization. Adding a custom differentiation rule for `_scan_apply` should fix this.
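One way such a rule could look, sketched with `jax.custom_vjp`: the forward pass runs a `lax.while_loop` unroll, and the backward pass borrows the VJP of an equivalent `lax.scan` unroll. `while_rnn`, `scan_rnn` and `rnn_cell` are hypothetical stand-ins for the custom scan and `_scan_apply`, not the library's actual code, and the vanilla RNN cell is an assumption chosen only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def rnn_cell(params, h, x):
    # Hypothetical vanilla RNN cell: h' = tanh(W h + U x + b)
    W, U, b = params
    return jnp.tanh(W @ h + U @ x + b)


def scan_rnn(params, h0, xs):
    """Reference unroll with lax.scan, which JAX differentiates in reverse mode."""
    def step(h, x):
        h_new = rnn_cell(params, h, x)
        return h_new, h_new  # carry the state and also emit it
    _, hs = lax.scan(step, h0, xs)
    return hs


@jax.custom_vjp
def while_rnn(params, h0, xs):
    """Hypothetical stand-in for a custom scan built on lax.while_loop."""
    n = xs.shape[0]
    hs0 = jnp.zeros((n,) + h0.shape, h0.dtype)

    def cond(state):
        i, _, _ = state
        return i < n

    def body(state):
        i, h, hs = state
        h = rnn_cell(params, h, xs[i])
        return i + 1, h, hs.at[i].set(h)  # store step i's hidden state

    _, _, hs = lax.while_loop(cond, body, (0, h0, hs0))
    return hs


def while_rnn_fwd(params, h0, xs):
    # Forward pass: run the while-loop version and save the inputs as residuals.
    return while_rnn(params, h0, xs), (params, h0, xs)


def while_rnn_bwd(residuals, g):
    # Backward pass: reuse the VJP of the equivalent lax.scan unroll.
    params, h0, xs = residuals
    _, vjp_fn = jax.vjp(scan_rnn, params, h0, xs)
    return vjp_fn(g)  # cotangents for (params, h0, xs)


while_rnn.defvjp(while_rnn_fwd, while_rnn_bwd)
```

With the rule registered, `jax.grad` of a loss built on `while_rnn` no longer has to differentiate the `while` primitive itself. The backward pass recomputes the forward unroll inside `jax.vjp`, trading an extra forward pass for simplicity. A `custom_vjp` rule covers reverse mode (`jax.grad`) only; calling `jax.jvp` on `while_rnn` would need a `jax.custom_jvp` rule instead.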