mcabbott / SliceMap.jl

Same-same but different

`ChainRulesCore.NoTangent` error when updating Zygote 0.6.21 -> 0.6.22 #12

Closed: sdobber closed this issue 2 years ago

sdobber commented 3 years ago

The following code defines part of a neural net that can be differentiated with Zygote 0.6.21. After upgrading to 0.6.22, it throws `ERROR: MethodError: no method matching JuliennedArrays.Slices(::ChainRulesCore.NoTangent, ::Int64)`.

using Flux
using Flux.Zygote
using SliceMap
using JuliennedArrays

# Define a version of `Recur` that returns the whole hidden state
mutable struct HiddenRecur{T,S}
    cell::T
    state::S
end

HiddenRecur(m) = HiddenRecur(m, m.state0)

function (m::HiddenRecur)(xs...)
    h, y = m.cell(m.state, xs...)
    m.state = h
    return h 
end

# Feed arrays sequentially to recurrent neural nets
mutable struct Seq{T}
    chain::T
    state
end
Seq(chain) = Seq(chain, [0.0f0])
(l::Seq)(x) = l(l.chain.state, l.chain, x)
function (l::Seq)(::Tuple, _, x)
    tuples = map(l.chain, Slices(x, True(), False()))
    l.state = [Align(map(x -> dropdims(x[i], dims=2), tuples), 1) for i in 1:length(l.chain.state)]
    return l.state
end

# Quick gradient
function bw(m, ip)
    gs = gradient((m, x) -> sum(m(x)), m, ip)
end

# Actual Code
inp = rand(Float32, 10, 50)
encoder = Seq(HiddenRecur(Flux.LSTMCell(10, 5)))

encoder(inp)  # forward pass works - returns an array of 2 arrays with the hidden state
bw(x -> encoder(x)[1], inp)  # works in Zygote 0.6.21, errors in 0.6.22

mcabbott commented 3 years ago

This package should really move to ChainRulesCore.jl, but when I started on that I was reminded that it's a slightly bigger job than I thought.

However, this also sounds like one more place where ChainRules types are leaking into Zygote. I didn't isolate this one, but with https://github.com/FluxML/Zygote.jl/pull/1104 it at least seems to run without error:

julia> bw(x -> encoder(x)[1], inp)  # works in Zygote 0.6.21, errors in 0.6.22
(nothing, Float32[0.36154643 0.26122904 … 0.10572737 0.10822818; 0.5348399 0.58685553 … 0.29224667 0.16772124; … ; -0.8771012 -0.93331707 … -0.38873446 -0.21984804; -0.12109001 -0.33585894 … -0.09494242 -0.102880865])

sdobber commented 2 years ago

Solved with a newer version of Zygote (or of one of its dependencies).