juliuskunze / jaxnet

Concise deep learning for JAX
Apache License 2.0
184 stars, 14 forks

RNN training fails #2

Closed: juliuskunze closed this issue 4 years ago.

juliuskunze commented 5 years ago

Optimizing an RNN fails with `NotImplementedError: Forward-mode differentiation rule for 'while' not implemented`.

This can be reproduced by running the OCR/RNN example after removing its `break` statement.
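The sketch below isolates the underlying limitation with a toy scalar RNN. The functions `rnn_scan` and `rnn_while` are hypothetical illustrations, not jaxnet code. Both compute the same recurrence, but only the `lax.scan` version can be differentiated in reverse mode. In recent JAX releases, forward mode through `lax.while_loop` is supported, so the failure shows up at the reverse-mode (transpose) step instead. The exact exception type and message therefore depend on the JAX version and will not match the message quoted above.

```python
import jax
import jax.numpy as jnp
from jax import lax

def rnn_scan(w, xs):
    # Toy RNN h_t = tanh(w * h_{t-1} + x_t), unrolled with lax.scan.
    def step(h, x):
        return jnp.tanh(w * h + x), None
    h_final, _ = lax.scan(step, jnp.zeros(()), xs)
    return h_final

def rnn_while(w, xs):
    # The same recurrence, unrolled with lax.while_loop instead.
    def cond(state):
        i, _ = state
        return i < xs.shape[0]
    def body(state):
        i, h = state
        return i + 1, jnp.tanh(w * h + xs[i])
    _, h_final = lax.while_loop(cond, body, (0, jnp.zeros(())))
    return h_final

w = jnp.float32(0.5)
xs = jnp.linspace(0.0, 1.0, 4)

# Reverse-mode differentiation through lax.scan works.
print(jax.grad(rnn_scan)(w, xs))

# Reverse-mode differentiation through lax.while_loop raises an error.
try:
    jax.grad(rnn_while)(w, xs)
except Exception as e:  # exact type/message varies across JAX versions
    print(type(e).__name__, e)
```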

`scan` from JAX (`jax.lax.scan`) already supports training, but we have to use a custom version of `scan` to allow parameterization. Adding a custom differentiation rule for `_scan_apply` should fix this.
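One way to attach such a rule in current JAX is `jax.custom_vjp`. You keep a forward pass that JAX cannot reverse-differentiate on its own, and you supply the backward pass yourself. The sketch below is hypothetical. `rnn_while`, `rnn_scan`, `_rnn_fwd` and `_rnn_bwd` are illustrative names, not jaxnet's `_scan_apply`, and the toy scalar RNN is an assumption. The backward pass borrows its gradient from an equivalent, differentiable `lax.scan` formulation.

```python
import jax
import jax.numpy as jnp
from jax import lax

def _cell(w, h, x):
    # One step of a toy scalar RNN.
    return jnp.tanh(w * h + x)

def rnn_scan(w, xs):
    # Differentiable reference formulation using lax.scan.
    h_final, _ = lax.scan(lambda h, x: (_cell(w, h, x), None), jnp.zeros(()), xs)
    return h_final

@jax.custom_vjp
def rnn_while(w, xs):
    # Forward pass via lax.while_loop (not reverse-differentiable by itself).
    def cond(state):
        i, _ = state
        return i < xs.shape[0]
    def body(state):
        i, h = state
        return i + 1, _cell(w, h, xs[i])
    _, h_final = lax.while_loop(cond, body, (0, jnp.zeros(())))
    return h_final

def _rnn_fwd(w, xs):
    # Save the inputs as residuals; the backward pass recomputes from them.
    return rnn_while(w, xs), (w, xs)

def _rnn_bwd(res, g):
    w, xs = res
    # Custom rule: take the cotangents from the scan formulation's VJP.
    _, vjp = jax.vjp(rnn_scan, w, xs)
    return vjp(g)  # (cotangent for w, cotangent for xs)

rnn_while.defvjp(_rnn_fwd, _rnn_bwd)

w = jnp.float32(0.5)
xs = jnp.linspace(0.0, 1.0, 4)
print(jax.grad(rnn_while)(w, xs))  # now matches jax.grad(rnn_scan)(w, xs)
```

This rule reruns the whole scan in the backward pass, trading extra compute for simplicity. A rule tuned for training would more likely save the per-step hidden states as residuals instead.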