FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Make RNNs blocked (and maybe fixing gradients along the way) #2258

Open · mkschleg opened this issue 1 year ago

mkschleg commented 1 year ago

Motivation and description

Given #2185 and other issues caused by the current mutability of the Recur interface, we should move to a more standard blocked (i.e. 3D for a simple RNN) interface. This has three benefits: 1) it cleans up the recurrent interface so it is easier to use for people coming from other packages, 2) it makes workflows using convRNNs easier to enable, and 3) it potentially enables optimizations on the Flux side (see Lux's Recurrence with return_sequence=true vs false).

I have not tested how we might fix the gradients by moving to this restricted interface. But if we decide to remove the statefulness (see below), we can fix gradients as shown in https://github.com/FluxML/Fluxperimental.jl/pull/7.

Possible Implementation

I see two ways to make this change: one is a wider change to the Flux Chain interface, and the other tries to fix only Recur. In either case, the implementation would assume the final dimension of the multi-dimensional input array is the time index. For a simple RNN, the incoming array would have dimensions Features x Batch x Time. Passing a 1D or 2D array to Recur would raise an error, to avoid ambiguity.
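A minimal sketch of the time-last convention in plain Julia (no Flux dependency). `SimpleCell` and `run_sequence` are hypothetical names for illustration, not Flux's API: the loop walks the last (time) dimension of a Features x Batch x Time array and rejects anything that is not 3D, as proposed above.

```julia
# Hypothetical simple RNN cell: h_t = tanh.(Wx * x_t .+ Wh * h_{t-1} .+ b)
struct SimpleCell{M<:AbstractMatrix,V<:AbstractVector}
    Wx::M  # hidden x features
    Wh::M  # hidden x hidden
    b::V   # hidden
end

(c::SimpleCell)(h, x) = tanh.(c.Wx * x .+ c.Wh * h .+ c.b)

# Apply the cell over the last (time) dimension of a Features x Batch x Time array.
function run_sequence(cell::SimpleCell, x::AbstractArray)
    # Refuse 1D/2D input so there is no ambiguity about which axis is time.
    ndims(x) == 3 || throw(ArgumentError("expected a Features x Batch x Time array, got ndims = $(ndims(x))"))
    _, batch, T = size(x)
    nhidden = size(cell.Wh, 1)
    h = zeros(eltype(x), nhidden, batch)     # initial hidden state
    outs = similar(x, nhidden, batch, T)     # Hidden x Batch x Time output
    for t in 1:T
        h = cell(h, view(x, :, :, t))        # one timestep, all batch columns at once
        outs[:, :, t] = h
    end
    return outs
end

cell = SimpleCell(fill(0.1, 4, 3), fill(0.05, 4, 4), zeros(4))
x = reshape(collect(1.0:30.0), 3, 2, 5) ./ 30   # 3 features, batch of 2, 5 timesteps
y = run_sequence(cell, x)                       # size (4, 2, 5)
```

Because the whole sequence arrives as one array, the layer never needs to hold state between calls.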

One possible implementation is to go all the way and remove state from the network in general. See https://github.com/FluxML/Fluxperimental.jl/pull/7. This would overhaul large parts of the Chain interface and could be targeted at 0.14. See the implementation in the above PR and https://github.com/FluxML/Fluxperimental.jl/pull/5 for details.

The second possible approach is to first remove only the loop-over-timesteps interface and replace it with the 3D interface. This initial change restricts the interface to 3D, but I haven't tested how we could fix gradients while keeping mutability and statefulness in Recur. The interface/implementation would likely look much like: https://github.com/FluxML/Flux.jl/blob/c9c262db1c851cc612389f86854b1987083aab25/src/layers/recurrent.jl#L184-L188

ToucheSir commented 1 year ago

On your second approach, how about emulating what PyTorch does with immutable struct wrappers over the RNN cell types? Say, const LSTM = NewRecur{LSTMCell}. This API would accept only 3D+ sequences and would return the hidden state. We could make passing the hidden state optional with a signature like (::NewRecur)(x, h = <default value>). Much like PyTorch and TensorFlow, we could add type parameters to toggle whether the RNN is bidirectional and whether it returns the full sequence or only the last timestep.
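A rough sketch of the wrapper idea in plain Julia. `SeqRNN`, `SimpleCell`, and `initial_state` are hypothetical stand-ins, not the actual `NewRecur` from Fluxperimental: the wrapper is immutable, a `Bool` type parameter chooses between returning the whole sequence or only the last step, the hidden state is an optional second argument, and the new state is always handed back to the caller. Bidirectionality is left out.

```julia
# Hypothetical cell, same form as a simple RNN: h_t = tanh.(Wx * x_t .+ Wh * h_{t-1} .+ b)
struct SimpleCell{M<:AbstractMatrix,V<:AbstractVector}
    Wx::M
    Wh::M
    b::V
end
(c::SimpleCell)(h, x) = tanh.(c.Wx * x .+ c.Wh * h .+ c.b)

# Hypothetical immutable wrapper; `Seq == true` returns every timestep, `false` only the last.
struct SeqRNN{Seq,C}
    cell::C
end
SeqRNN(cell; return_sequence::Bool = true) = SeqRNN{return_sequence,typeof(cell)}(cell)

# Default hidden state: zeros of size Hidden x Batch.
initial_state(c::SimpleCell, x) = zeros(eltype(x), size(c.Wh, 1), size(x, 2))

# Hidden state is optional; only 3D (Features x Batch x Time) input has a method.
(m::SeqRNN)(x::AbstractArray{<:Any,3}) = m(x, initial_state(m.cell, x))

function (m::SeqRNN{Seq})(x::AbstractArray{<:Any,3}, h) where {Seq}
    ys = similar(x, size(h, 1), size(h, 2), size(x, 3))
    for t in axes(x, 3)
        h = m.cell(h, view(x, :, :, t))
        ys[:, :, t] = h
    end
    # Return (output, new_state); the caller threads the state into the next call.
    return (Seq ? ys : h), h
end

cell = SimpleCell(fill(0.1, 4, 3), fill(0.05, 4, 4), zeros(4))
x = reshape(collect(1.0:30.0), 3, 2, 5) ./ 30
ys, h = SeqRNN(cell)(x)                               # ys is 4 x 2 x 5
ylast, _ = SeqRNN(cell; return_sequence = false)(x)   # ylast is 4 x 2
```

Since the call returns an `(output, state)` tuple, following it with `first` or `last` in a pipeline picks out the output or the new hidden state, which matches the Chain integration described in the next comment.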

Integration with Chain in this approach would be a bit more work, since preceding and following layers would need to be sequence-aware. In practice, though, I haven't seen many people actually take advantage of being able to build combinations like Chain(Dense(), RNN(), LayerNorm()) to apply non-recurrent layers per timestep. Extracting just the new hidden state or the output from the RNN return value would be a simple matter of adding first or last after it in the chain.

This approach avoids having to deal with state by making users save it and carry it over themselves. It's less ergonomic, since they'd have to thread those states through larger models, but that's more than doable with layers like Parallel. Going stateless also makes AD easier! The big remaining AD-related issue I see is that differentiating through loops is still slow, but that is easier to address with a 3D sequence-based interface by defining rules for NewRecur. Maybe this rule could be attached to a JAX-like scan function so that people can use it for their own recurrent layers.
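To make the scan idea concrete, here is a hedged sketch of a JAX-style `scan` in plain Julia. The name `scan` and its `f(carry, x_t) -> (new_carry, y_t)` contract are assumptions modeled on `jax.lax.scan`, not an existing Flux function. Only the forward pass is shown; the point is that a single function owning the time loop is the natural place to attach one custom AD rule.

```julia
# Hypothetical JAX-style scan over the last dimension of `xs`.
# `f(carry, x_t)` must return `(new_carry, y_t)`; the y_t are concatenated
# along a new trailing dimension.
function scan(f, carry, xs::AbstractArray)
    d = ndims(xs)
    ys = []
    for t in axes(xs, d)
        carry, y = f(carry, selectdim(xs, d, t))
        push!(ys, y)
    end
    return carry, cat(ys...; dims = ndims(first(ys)) + 1)
end

# Example: a simple RNN step written as a scan body (placeholder weights).
const Wx = fill(0.1, 4, 3)
const Wh = fill(0.05, 4, 4)
const bias = zeros(4)
rnn_step(h, x) = (hn = tanh.(Wx * x .+ Wh * h .+ bias); (hn, hn))

x = reshape(collect(1.0:30.0), 3, 2, 5) ./ 30   # Features x Batch x Time
hT, hs = scan(rnn_step, zeros(4, 2), x)         # hs is 4 x 2 x 5, hT is the final state
```

A user-defined recurrent layer would only need to supply its own step function; any gradient rule written for `scan` would then cover it.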

mkschleg commented 1 year ago

Good idea. As far as I can tell, this is also similar to Haiku's approach. It could also give us an opportunity to provide an interface for static vs. dynamic unrolling. At first we should just use a loop, but there might be an opportunity to use a generated function to replace the loop with the unrolled version. That might be more trouble than it's worth, though, depending on how the scan function turns out.
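For illustration only, here is one way the static vs. dynamic distinction could look in plain Julia. `unroll_dynamic` and `unroll_static` are hypothetical names, and this is not how Flux or Fluxperimental implement it. The dynamic version is an ordinary loop over a runtime-length sequence; the static version takes the timesteps as a tuple, so the length `N` is part of the type and a `@generated` function can emit the fully unrolled chain of calls.

```julia
# Dynamic unroll: an ordinary loop; the sequence length is a runtime value.
function unroll_dynamic(f, h, xs)
    for x in xs
        h = f(h, x)
    end
    return h
end

# Static unroll: N comes from the tuple type, so the generated body is
# f(...f(f(h, xs[1]), xs[2])..., xs[N]) with no loop at all.
@generated function unroll_static(f, h, xs::NTuple{N,Any}) where {N}
    ex = :h
    for t in 1:N
        ex = :(f($ex, xs[$t]))
    end
    return ex
end

decay_step(h, x) = 0.5 .* h .+ x
xs = ntuple(i -> fill(Float64(i), 2), 4)   # 4 timesteps, each a length-2 vector
unroll_static(decay_step, zeros(2), xs) == unroll_dynamic(decay_step, zeros(2), xs)  # true
```

The catch is that the generated version compiles a fresh method for every sequence length, which is one reason it may not be worth the trouble for long or variable-length sequences.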

mkschleg commented 1 year ago

I added a first pass of this functionality to Fluxperimental.jl. Some lingering issues still need to be resolved, but it gives us an idea of what we need to support. I think we should discuss how to solve the issue of returning the carry back to the user. @ToucheSir suggested using Parallel streams to solve this, but I think that would be a pretty horrendous interface to use.

mkschleg commented 1 year ago

OK. We have merged a potential interface (NewRecur) into Fluxperimental, and I think we can make a push to get this into v0.15. This seems necessary if we want to fully remove the old gradient interface. Thoughts? I can work on a PR (I've finally gotten through my PhD defense, so time is less of an issue).

ToucheSir commented 1 year ago

I like that idea. We should probably figure out how to make the rrule type-stable as part of that :P

mkschleg commented 1 year ago

Working PR in #2316.