liuyxpp opened this issue 1 year ago
If you're ok with the initial state being non-trainable, then using one of the functions under https://juliadiff.org/ChainRulesCore.jl/stable/api.html#Ignoring-gradients on the `reset!` line should work, e.g. `@ignore_derivatives Flux.reset!(model)`. Moving the call to `reset!` outside of the loss function would also do the trick.
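For reference, a minimal sketch of that approach (the model, data, and dimensions below are hypothetical, not from the issue):

```julia
using Flux
using ChainRulesCore: @ignore_derivatives

model = Chain(RNN(2 => 4), Dense(4 => 1))   # hypothetical toy model
xs = [rand(Float32, 2, 8) for _ in 1:5]     # 5 timesteps, batch of 8
ys = [rand(Float32, 1, 8) for _ in 1:5]

function loss(m, xs, ys)
    # The mutation inside reset! is hidden from Zygote, so explicit-mode
    # differentiation never sees it.
    @ignore_derivatives Flux.reset!(m)
    sum(Flux.mse(m(x), y) for (x, y) in zip(xs, ys))
end

opt_state = Flux.setup(Adam(), model)
grads = Flux.gradient(m -> loss(m, xs, ys), model)
Flux.update!(opt_state, model, grads[1])
```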
Ah, thanks! Can you explain in more detail why this fails in explicit mode but not in implicit mode?
BTW, if I have extra data to train the initial state for each time sequence, how should I do that?
I'm not sure why it fails. The RNN API is a weird one because it uses some of the implicit mode machinery even when you use explicit mode.
> if I have extra data to train the initial state for each time sequence, how should I do that?
If you want to have separate initial states for each sample like you mentioned in https://github.com/FluxML/Flux.jl/issues/2185#issuecomment-1736563421, the best bet would be to use the underlying RNN cell API (e.g. `RNN` -> `RNNCell`) and write your own loop over the timesteps. It'll be more manual work than using the `Recur`-based API, but it should just work and also avoid the MethodError shown above.
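A minimal sketch of what such a manual loop could look like, with hypothetical toy dimensions and a per-sample initial state `h0` passed in as an explicit argument so that it receives a gradient and can be trained:

```julia
using Flux

cell = Flux.RNNCell(2 => 4)                 # hypothetical dimensions
dense = Dense(4 => 1)
xs = [rand(Float32, 2, 8) for _ in 1:5]     # 5 timesteps, batch of 8
ys = rand(Float32, 1, 8)                    # target at the final step
h0 = rand(Float32, 4, 8)                    # one initial state per sample

function loss(cell, dense, h0, xs, ys)
    h = h0
    for x in xs
        h, _ = cell(h, x)                   # manual unroll over timesteps
    end
    Flux.mse(dense(h), ys)
end

# h0 is a plain argument, so Zygote returns a gradient for it too,
# making the initial state trainable alongside the layer parameters.
grads = Flux.gradient((c, d, h) -> loss(c, d, h, xs, ys), cell, dense, h0)
```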
Got that and I will report back once I figure it out. Many thanks!
I am trying to reproduce the tutorial A Basic RNN using Flux.jl v0.14.6. Using the old Flux API as in the tutorial, the model can be successfully trained. The code is:
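(The snippet itself is omitted in this excerpt; the following is a hypothetical minimal reconstruction of the tutorial-style implicit API, not the original code.)

```julia
using Flux

model = Chain(RNN(2 => 4), Dense(4 => 1))   # hypothetical toy model/data
xs = [rand(Float32, 2, 8) for _ in 1:5]
ys = [rand(Float32, 1, 8) for _ in 1:5]

function loss(xs, ys)
    Flux.reset!(model)                      # fine under the implicit API
    sum(Flux.mse(model(x), y) for (x, y) in zip(xs, ys))
end

ps = Flux.params(model)                     # implicit parameter collection
opt = Flux.Optimise.Adam()
Flux.train!(loss, ps, [(xs, ys)], opt)      # old-style implicit train!
```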
However, refactoring the above code to use the new explicit API makes Zygote complain:
The code is as follows:
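(Again a hypothetical reconstruction rather than the original code: the same model moved to the explicit API, with `Flux.reset!` still inside the loss, which is what Zygote trips over.)

```julia
using Flux

model = Chain(RNN(2 => 4), Dense(4 => 1))
xs = [rand(Float32, 2, 8) for _ in 1:5]
ys = [rand(Float32, 1, 8) for _ in 1:5]

function loss(m, xs, ys)
    Flux.reset!(m)   # mutation inside the differentiated code path
    sum(Flux.mse(m(x), y) for (x, y) in zip(xs, ys))
end

opt_state = Flux.setup(Adam(), model)
# The explicit gradient call is where the error above is raised.
grads = Flux.gradient(m -> loss(m, xs, ys), model)
Flux.update!(opt_state, model, grads[1])
```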
Julia version info