PumasAI / SimpleChains.jl

Simple chains
MIT License

Zygote derivative wrt chain input `x` returns StackOverflowError #108

Open: axsk opened this issue 1 year ago

axsk commented 1 year ago
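using SimpleChains, Zygote

# 2 inputs -> elementwise cube -> 50 tanh units -> 2 linear outputs
sc = SimpleChain(
    static(2),
    Activation(x -> x.^3),
    TurboDense{true}(tanh, static(50)),
    TurboDense{true}(identity, static(2))
)

p_nn = SimpleChains.init_params(sc)  # flat parameter vector for `sc`

# Gradient of the summed chain output with respect to the input x
Zygote.gradient([2.,2.]) do x
    sc(x, p_nn) |> sum
end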

Running this returns:

ERROR: StackOverflowError:
Stacktrace:
 [1] pullback!(pg::Ptr{Float32}, td::TurboDense{true, Static.StaticInt{2}, typeof(identity)}, C̄::FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, B::StrideArraysCore.PtrArray{Tuple{Static.StaticInt{50}}, (true,), Float32, 1, 1, 0, (1,), Tuple{Static.StaticInt{4}}, Tuple{Static.StaticInt{1}}}, p::Ptr{Float32}, pu::Ptr{UInt8}, pu2::Ptr{UInt8}) (repeats 79984 times)
   @ SimpleChains ~/.julia/packages/SimpleChains/HhLUa/src/dense.jl:904

Using Julia 1.8.1 and SimpleChains 0.3.1.

axsk commented 1 year ago

The problem seems to be with `sum`, not with the derivative with respect to `x`: using `sum(abs2, sc(x, p_nn))` instead of the plain `sum` works.

However, taking the gradient with respect to the parameters instead leads to the same StackOverflowError:

Zygote.gradient(p_nn) do p 
    sum(sc(rand(2), p))
end
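Until `Fill` cotangents are handled, one possible workaround is to hide `sum` behind a small helper whose ChainRulesCore `rrule` returns a dense `Array` cotangent instead of the lazy `FillArrays.Fill` seen in the trace. Zygote picks up ChainRulesCore rules, so the chain's pullback would then receive an ordinary `Vector`. This is a sketch we have not verified against SimpleChains 0.3.1; the name `dense_sum` is hypothetical and not part of SimpleChains or Zygote.

```julia
using ChainRulesCore

# Hypothetical helper (not part of SimpleChains or Zygote): computes the same
# value as `sum`, but with a custom reverse rule.
dense_sum(y::AbstractArray) = sum(y)

function ChainRulesCore.rrule(::typeof(dense_sum), y::AbstractArray)
    function dense_sum_pullback(Δ)
        # Every entry of the cotangent equals Δ; materialize it as a dense Array
        # with y's element type, rather than a lazy Fill.
        ȳ = fill(convert(eltype(y), unthunk(Δ)), size(y))
        return (NoTangent(), ȳ)
    end
    return dense_sum(y), dense_sum_pullback
end
```

With this in scope, the two failing calls would become `Zygote.gradient(x -> dense_sum(sc(x, p_nn)), [2.0, 2.0])` and `Zygote.gradient(p -> dense_sum(sc(rand(2), p)), p_nn)`. The cost is one extra dense allocation of the output's size per backward pass.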
chriselrod commented 1 year ago

The solution here will be to add a few methods that support `Fill` arrays (from FillArrays.jl), like the `C̄::FillArrays.Fill` cotangent in the trace.
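To illustrate the kind of change meant here, the sketch below uses a hypothetical kernel `scale_pullback!` (not SimpleChains' actual `pullback!`) whose fast path only accepts strided memory. A generic `AbstractVector` fallback that merely forwards its argument would dispatch back to itself when handed a `Fill`, which is one way a frame repeating thousands of times, as in the trace, can arise. A dedicated `Fill` method that materializes the array avoids that. The real fix in SimpleChains may well differ, for example by exploiting the fact that every entry of a `Fill` has the same value.

```julia
using FillArrays

# Hypothetical pullback kernel: the fast path only handles strided (dense) memory.
function scale_pullback!(out::Vector{Float64}, c̄::StridedVector{Float64})
    out .= 2 .* c̄  # the cotangent of y = 2x is 2c̄
    return out
end

# A fallback of this shape would recurse until the stack overflows on a Fill,
# because `convert(AbstractVector, c̄)` returns `c̄` unchanged:
# scale_pullback!(out, c̄::AbstractVector) = scale_pullback!(out, convert(AbstractVector, c̄))

# Fill-specific method: collect the lazy Fill into a dense Vector,
# then call the strided fast path.
scale_pullback!(out::Vector{Float64}, c̄::Fill{Float64,1}) = scale_pullback!(out, collect(c̄))
```

With both methods defined, `scale_pullback!(zeros(3), Fill(1.5, 3))` dispatches to the `Fill` method and returns `[3.0, 3.0, 3.0]`.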