Closed: eamartin closed this issue 4 years ago
I am hitting the same issue. I was about to post, but you beat me to it :-)
In my case, I succeed with a small amount of training data but get a stack overflow with more examples. In my app all sequences are the same length (8).
Initially, I thought there was some 'bad' training example in my data, but all of my data can be used successfully if it is broken into small enough chunks: training on 512 examples works, but 768 causes a stack overflow.
I am using a GRU.
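For reference, the chunking workaround described above looks roughly like this. This is a sketch against the Tracker-era Flux API, not the poster's actual code; `train_chunked!` is a hypothetical name, and the idea is that `Flux.truncate!` cuts the stored gradient history between chunks so each backward pass only recurses over `chunk` steps:

```julia
using Flux

m = GRU(10, 10)

# Split a long sequence into chunks so each backward pass
# only has to recurse over `chunk` time steps.
function train_chunked!(m, xs, ys; chunk = 512)
    for i in 1:chunk:length(xs)
        r = i:min(i + chunk - 1, length(xs))
        l = sum(Flux.mse(m(x), y) for (x, y) in zip(xs[r], ys[r]))
        Flux.back!(l)        # backprop over this chunk only
        Flux.truncate!(m)    # drop the gradient history at the boundary
    end
end
```

This trades exact gradients for bounded recursion depth (it is truncated BPTT), which matches the observation that small chunks succeed where the full sequence overflows.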
Yes, I think you're spot on about this coming down to recursion. It should be quite straightforward to switch this to something more iterative / keep track of the stack manually to avoid this. This specific issue won't affect Zygote, though it's worth noting that you'll still have memory usage linear in the number of loop iterations, so eventually you need to use something like checkpointing here either way.
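The recursion-to-iteration rewrite suggested here is the standard trick of replacing call-stack frames with an explicit data structure. A generic sketch (all names here, `accumulate_grad!` and `parents`, are hypothetical placeholders, not Flux internals):

```julia
# Instead of `back(node)` calling `back(parent)` once per time step,
# which grows the call stack linearly in sequence length,
# walk the graph with an explicit worklist.
function backward!(node)
    stack = [node]
    while !isempty(stack)
        n = pop!(stack)
        accumulate_grad!(n)         # hypothetical per-node gradient step
        append!(stack, parents(n))  # visit predecessors iteratively
    end
end
```

The heap-allocated `stack` grows instead of the call stack, so recursion depth is no longer bounded by the runtime's stack size; memory use is still linear in the number of loop iterations, as noted above.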
I had a go at this; you can try it with `add Flux#iterate`, which should make this issue go away. (It isn't quite polished yet, and not all tests pass.)
Running into this same issue with a large number of time samples.
Wanted to follow up on this issue to see what the longer term solution to this will be. Would Zygote be a good way to get around this issue?
When I made my previous post back in Feb, I was looking into using Flux for tracking gradients of solutions to PDEs in the time domain (the scalar wave equation). Due to getting stack overflows in Flux for moderate- to long-term time sequences, as well as Zygote's immaturity (at the time), I ended up just using pytorch: https://github.com/fancompute/wavetorch. Of course, the optimal approach in either framework is probably to manually define the adjoint recurrence relation and implement primitives (to use the HIPS autograd terminology) that hook into AD. We ended up doing this for our pytorch implementation and it significantly reduced memory usage during backprop. However, at a minimum (without checkpointing) you still need to store the forward solution at every time step. I'm not sure about the detailed differences between the Flux and pytorch AD implementations, but I fear that Flux won't be able to handle this.
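In Julia, the "manually define the adjoint recurrence" approach maps onto `Zygote.@adjoint` (or, in newer code, a ChainRules `rrule`). A schematic for a single time step, where `step` is a made-up toy update rather than anything from wavetorch:

```julia
using Zygote

# Toy wave-like recurrence update (illustrative only).
step(u, uprev, c) = 2u .- uprev .+ c .* u

Zygote.@adjoint function step(u, uprev, c)
    y = step(u, uprev, c)
    # Hand-written pullback: no primal tracing needed.
    # d/du = 2 + c, d/duprev = -1, d/dc = u (all elementwise).
    y, ȳ -> (2 .* ȳ .+ c .* ȳ, -ȳ, ȳ .* u)
end
```

Defining the pullback by hand avoids storing AD-traced intermediates for each step, though as noted, without checkpointing the forward solution itself still has to be kept at every time step.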
What are your thoughts? I would really like to use Julia for some of my future projects in this area and Flux/Zygote seem like really nice packages. Perhaps my use case is too much of an outlier compared to typical ML workloads. Thanks again for your work on this great package!
Zygote is pretty good at this sort of thing (and is now the default in Flux, on master). There's a simple implementation of checkpointing built in, and we've also used it a bit with DiffEqFlux, which needs a lot of similar tricks. It also isn't prone to stack overflows, which gets rid of one class of issues straight away.
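For concreteness, the checkpointing mentioned here can be sketched as follows. This assumes `Zygote.checkpointed`, which re-runs the wrapped function during the pullback instead of caching its intermediates; treat the exact name as an assumption about the Zygote version in use:

```julia
using Zygote

f(x) = sum(abs2, x)
x0 = rand(3)

# Both gradients are the same; the checkpointed version recomputes
# f's internals on the backward pass rather than storing them.
g1 = gradient(f, x0)
g2 = gradient(x -> Zygote.checkpointed(f, x), x0)
```

Wrapping the body of a time-stepping loop this way trades extra compute for memory that no longer grows with what is stored per iteration.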
Consider the following code
This outputs
However, if I increase `N` to `1 << 13`, then it throws a `StackOverflowError`.
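The original snippet isn't preserved in this thread, but a minimal reproduction in the same spirit (Tracker-era Flux; the model size and loss are assumptions, only the `N = 1 << 13` threshold comes from the report) would look something like:

```julia
using Flux

N = 1 << 12                # works; 1 << 13 reportedly overflows the stack
m = GRU(2, 2)
xs = [rand(Float32, 2) for _ in 1:N]

l = sum(sum, m.(xs))       # run the RNN over N time steps
Flux.back!(l)              # recursive backward pass -> call stack depth ~ N
```

Each time step adds a frame to the recursive backward traversal, so doubling `N` eventually exceeds the runtime stack limit.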
My guess is that this happens because of the recursive nature of the backward algorithm. I don't know how feasible it is to modify backprop to use an iterative algorithm instead of a recursive one. I also don't know whether another autodiff system such as Zygote would handle this better.
Any suggestions for workarounds or plans to eventually fix this issue? Support for fast auto-diff over algorithms on long sequences would be valuable for time series analysis.