FluxML / Tracker.jl

Flux's ex AD

StackOverflowError for long recurrences #49

Closed eamartin closed 4 years ago

eamartin commented 5 years ago

Consider the following code:

import Random
import Flux
import Flux.Tracker

Random.seed!(2018)

N = (1 << 10)
x = randn(N)

function compute_ema(half_life, init_state)
    alpha = log(2) / half_life
    beta = exp(-alpha)

    h = init_state
    for i = 1:N
        h = beta * h + alpha * x[i]
    end
    h
end

grad = Tracker.gradient(compute_ema, 1000.0, 0.0)
println(grad)

This outputs

(-1.736128260716541e-5 (tracked), 0.4917510370131168 (tracked))
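
For reference, the recurrence unrolls to the closed form h_N = beta^N * h_0 + alpha * sum_i beta^(N-i) * x[i], so the second entry can be checked by hand: d h_N / d init_state = beta^N = 2^(-N / half_life).

2.0^(-(1 << 10) / 1000.0)   # ≈ 0.49175, matching the tracked value for init_state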

However, if I increase N to (1 << 13), it throws a StackOverflowError:

ERROR: LoadError: StackOverflowError:
Stacktrace:
 [1] accum!(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Flux.Tracker.TrackedReal{Float64}) at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:100
 [2] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Flux.Tracker.TrackedReal{Float64}) at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:115
 [3] (::getfield(Flux.Tracker, Symbol("##4#5")){Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Float64}, ::Flux.Tracker.TrackedReal{Float64}) at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:106
 [4] foreach at ./abstractarray.jl:1836 [inlined]
 [5] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##250#253")){Flux.Tracker.TrackedReal{Float64},Flux.Tracker.TrackedReal{Float64}},Tuple{Flux.Tracker.Tracked{Float64},Flux.Tracker.Tracked{Float64}}}, ::Flux.Tracker.TrackedReal{Float64}) at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:106
 [6] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Flux.Tracker.TrackedReal{Float64}) at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:118
 ... (the last 4 lines are repeated 1 more time)
 [11] (::getfield(Flux.Tracker, Symbol("##4#5")){Flux.Tracker.Grads})(::Flux.Tracker.Tracked{Float64}, ::Flux.Tracker.TrackedReal{Float64}) at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:106
 [12] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##250#253")){Flux.Tracker.TrackedReal{Float64},Flux.Tracker.TrackedReal{Float64}},Tuple{Flux.Tracker.Tracked{Float64},Flux.Tracker.Tracked{Float64}}}, ::Flux.Tracker.TrackedReal{Float64}) at ./abstractarray.jl:1836
 ... (the last 7 lines are repeated 6859 more times)
 [48026] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Int64) at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:118
 [48027] #4 at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:106 [inlined]
 [48028] foreach at ./abstractarray.jl:1836 [inlined]
 [48029] back_(::Flux.Tracker.Grads, ::Flux.Tracker.Call{getfield(Flux.Tracker, Symbol("##226#229")),Tuple{Flux.Tracker.Tracked{Float64},Flux.Tracker.Tracked{Float64}}}, ::Int64) at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:106
 [48030] back(::Flux.Tracker.Grads, ::Flux.Tracker.Tracked{Float64}, ::Int64) at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:118
 [48031] #6 at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:131 [inlined]
 [48032] (::getfield(Flux.Tracker, Symbol("##9#11")){getfield(Flux.Tracker, Symbol("##6#7")){Flux.Tracker.Params,Flux.Tracker.TrackedReal{Float64}}})(::Int64) at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:140
 [48033] gradient(::Function, ::Float64, ::Vararg{Float64,N} where N) at /home/ermartin/.julia/packages/Flux/oN61x/src/tracker/back.jl:152
 [48034] top-level scope at none:0
 [48035] include at ./boot.jl:317 [inlined]
 [48036] include_relative(::Module, ::String) at ./loading.jl:1044
 [48037] include(::Module, ::String) at ./sysimg.jl:29
 [48038] exec_options(::Base.JLOptions) at ./client.jl:231
in expression starting at /home/ermartin/test.jl:22

My guess is that this happens because of the recursive nature of the backwards algorithm. I don't know how feasible it would be to modify backprop to use an iterative algorithm instead of a recursive one. I also don't know whether another autodiff system such as Zygote would handle this better.

Any suggestions for workarounds or plans to eventually fix this issue? Support for fast auto-diff over algorithms on long sequences would be valuable for time series analysis.
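
For what it's worth, here is a toy illustration (plain Julia, not Tracker's actual code) of the recursion point above: the same depth-first walk over a chain of nodes, written once recursively and once with an explicit, heap-allocated work stack. The ToyNode type and the chain length are made up for the example.

mutable struct ToyNode
    parents::Vector{ToyNode}
end
ToyNode() = ToyNode(ToyNode[])

# Recursive walk: call-stack depth equals the length of the longest parent chain.
function walk_recursive!(seen, n::ToyNode)
    n in seen && return seen
    push!(seen, n)
    foreach(p -> walk_recursive!(seen, p), n.parents)
    seen
end

# Iterative walk: the pending nodes live in a Vector, so depth is bounded by memory.
function walk_iterative(root::ToyNode)
    seen = Set{ToyNode}()
    stack = ToyNode[root]
    while !isempty(stack)
        n = pop!(stack)
        n in seen && continue
        push!(seen, n)
        append!(stack, n.parents)
    end
    seen
end

# A long chain, like the graph the EMA loop unrolls into:
chain = foldl((child, _) -> ToyNode([child]), 1:1_000_000; init = ToyNode())
length(walk_iterative(chain))                 # 1000001
# walk_recursive!(Set{ToyNode}(), chain)      # StackOverflowError on a chain this deep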

opus111 commented 5 years ago

I am hitting the same issue. I was about to post, but you beat me to it :-)

In my case, training succeeds with a small amount of training data but hits a stack overflow with more examples. In my app all sequences are the same length (8).

Initially I thought there was some 'bad' training example in my data, but all of my data can be used successfully if it is broken into small enough chunks: training on 512 examples works, but 768 causes a stack overflow.

I am using a GRU.

MikeInnes commented 5 years ago

Yes, I think you're spot on about this coming down to recursion. It should be quite straightforward to switch this to something more iterative and keep track of the stack manually, which would avoid the overflow. This specific issue won't affect Zygote, though it's worth noting that you'll still have memory usage linear in the number of loop iterations, so eventually you'll need something like checkpointing here either way.
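
To make the checkpointing idea concrete, here is a rough sketch for the EMA example above, written against the same global x and N. Nothing here is a built-in Flux/Tracker feature; the function names and the segment length are made up. Only the states at segment boundaries are kept from the forward pass, and each backward step re-runs one segment under Tracker, so the recursion depth stays proportional to seglen rather than N, at the cost of recomputing each segment once.

import Flux.Tracker

function ema_segment(half_life, h, r::UnitRange)
    alpha = log(2) / half_life
    beta = exp(-alpha)
    for i in r
        h = beta * h + alpha * x[i]
    end
    h
end

function checkpointed_ema_grad(half_life, init_state; seglen = 256)
    ranges = [i:min(i + seglen - 1, N) for i in 1:seglen:N]

    # Forward pass: record the plain (untracked) state at the start of each segment.
    bounds = Float64[init_state]
    for r in ranges[1:end-1]
        push!(bounds, ema_segment(half_life, bounds[end], r))
    end

    # Backward pass: sweep the segments in reverse, chaining d(output)/d(state) by hand.
    hbar = 1.0        # d h_final / d h_out of the current segment
    grad_hl = 0.0
    for s in length(ranges):-1:1
        g = Tracker.gradient((hl, h) -> hbar * ema_segment(hl, h, ranges[s]),
                             half_life, bounds[s])
        grad_hl += Tracker.data(g[1])   # this segment's contribution via alpha and beta
        hbar = Tracker.data(g[2])       # becomes d h_final / d h_in for the segment before
    end
    (grad_hl, hbar)   # gradients w.r.t. half_life and init_state
end

# checkpointed_ema_grad(1000.0, 0.0) should agree with the small-N result above.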

MikeInnes commented 5 years ago

I had a go at this; you can try it with add Flux#iterate, which should make this issue go away. (It isn't quite polished yet, and not all tests pass.)
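
In case it helps anyone trying this: assuming the branch is still around, it can be installed from the Pkg REPL with add Flux#iterate, or from a script:

using Pkg
Pkg.add(PackageSpec(name = "Flux", rev = "iterate"))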

ianwilliamson commented 5 years ago

Running into this same issue with a large number of time samples.

ianwilliamson commented 4 years ago

I wanted to follow up on this issue to see what the longer-term solution will be. Would Zygote be a good way to get around it?

When I made my previous post back in Feb, I was looking into using Flux to track gradients of solutions to PDEs in the time domain (the scalar wave equation). Because of the stack overflows in Flux for moderate to long time sequences, as well as Zygote's immaturity (at the time), I ended up just using PyTorch: https://github.com/fancompute/wavetorch.

Of course, the optimal approach in either framework is probably to manually define the adjoint recurrence relation and implement primitives (to use the HIPS autograd terminology) that hook into the AD. We ended up doing this in our PyTorch implementation and it significantly reduced memory usage during backprop. However, at a minimum (without checkpointing) you still need to store the forward solution at every time step. I'm not sure about the detailed differences between the Flux and PyTorch AD implementations, but I fear that Flux won't be able to handle this.

What are your thoughts? I would really like to use Julia for some of my future projects in this area and Flux/Zygote seem like really nice packages. Perhaps my use case is too much of an outlier compared to typical ML workloads. Thanks again for your work on this great package!
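
As a sketch of what "manually defining the adjoint recurrence relation" could look like on the Julia side, here is the EMA recurrence wrapped as a single primitive with a hand-written pullback via Zygote's @adjoint macro. The names ema_scan and ema are made up for the example, and, as noted above, the forward states still have to be stored at every step unless you add checkpointing on top.

using Zygote
import Zygote: @adjoint

# Run the recurrence once, keeping every intermediate state for the backward sweep.
function ema_scan(beta, alpha, h0, x)
    h = h0
    hs = similar(x, length(x) + 1)
    hs[1] = h
    for i in eachindex(x)
        h = beta * h + alpha * x[i]
        hs[i + 1] = h
    end
    h, hs
end

ema(beta, alpha, h0, x) = first(ema_scan(beta, alpha, h0, x))

# Hand-written adjoint: one reverse sweep over the stored states, no per-step graph nodes.
@adjoint function ema(beta, alpha, h0, x)
    h, hs = ema_scan(beta, alpha, h0, x)
    function ema_pullback(dh)
        gbeta, galpha, gx, gh = zero(beta), zero(alpha), zero(x), dh
        for i in length(x):-1:1
            gbeta  += gh * hs[i]     # h_{i-1} enters through beta * h_{i-1}
            galpha += gh * x[i]      # x_i enters through alpha * x_i
            gx[i]   = gh * alpha
            gh     *= beta           # flows back to h_{i-1}
        end
        (gbeta, galpha, gh, gx)      # matches the argument order (beta, alpha, h0, x)
    end
    h, ema_pullback
end

# Zygote differentiates the half_life -> (alpha, beta) mapping, then hits the custom adjoint:
# Zygote.gradient((hl, h0) -> ema(exp(-log(2) / hl), log(2) / hl, h0, x), 1000.0, 0.0)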

MikeInnes commented 4 years ago

Zygote is pretty good at this sort of thing (and is now the default in Flux, on master). There's a simple implementation of checkpointing built in, and we've also used it a bit with DiffEqFlux, which needs a lot of similar tricks. Zygote is also not prone to stack overflow issues, so that gets rid of one class of problems straight away.
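
For anyone landing here later, the original snippet can be checked against Zygote directly; a minimal sketch, assuming Flux master or a standalone Zygote install:

using Zygote

N = 1 << 13          # the size that overflowed Tracker above
x = randn(N)

function compute_ema(half_life, init_state)
    alpha = log(2) / half_life
    beta = exp(-alpha)
    h = init_state
    for i in 1:N
        h = beta * h + alpha * x[i]
    end
    h
end

Zygote.gradient(compute_ema, 1000.0, 0.0)   # returns plain Float64s, no tracked wrappers
# Memory still grows with N; Zygote.checkpointed(f, xs...) (if your Zygote version
# provides it) is the built-in way to trade recomputation for memory.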