FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

combining loops and addition causes a dimension mismatch #1261

Open samsartor opened 2 years ago

samsartor commented 2 years ago

Here is the minimal reproducer I came up with:

import Flux
import Zygote
using Functors

struct Test
    a
    b
end

@functor Test

function (m::Test)(x)
    a = x
    for f=m.a
        a = f(a)
    end
    b = x
    for f=m.b
        b = f(b)
    end
    a + b
end

t = Test([Flux.Dense(10=>5)], [Flux.Dense(10=>5)])
x = rand(10)
Zygote.gradient(() -> sum(t(x)), Flux.params(t))

The error, reported at the line for f=m.b, is:

ERROR: LoadError: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 5 and 10")
Stacktrace:
  [1] _bcs1
    @ ./broadcast.jl:516 [inlined]
  [2] _bcs
    @ ./broadcast.jl:510 [inlined]
  [3] broadcast_shape
    @ ./broadcast.jl:504 [inlined]
  [4] combine_axes
    @ ./broadcast.jl:499 [inlined]
  [5] instantiate
    @ ./broadcast.jl:281 [inlined]
  [6] materialize
    @ ./broadcast.jl:860 [inlined]
  [7] accum(x::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, ys::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/lib/lib.jl:25
  [8] Pullback
    @ repro.jl:18 [inlined]
  [9] (::typeof(∂(λ)))(Δ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [10] Pullback
    @ repro.jl:26 [inlined]
 [11] (::typeof(∂(#3)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#97#98"{Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, typeof(∂(#3)), Zygote.Context})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:357
 [13] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
    @ Zygote ~/.julia/packages/Zygote/IoW2g/src/compiler/interface.jl:76
 [14] top-level scope
    @ repro.jl:26
in expression starting at repro.jl:26

This happens even though all the dimensions actually match up correctly; running t(x) on its own works fine.
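Frame [7] of the trace suggests where the mismatch really arises: in the backward pass, when Zygote's accum merges two gradients of different lengths. One is the length-5 Fill from differentiating sum over the 5-element output, and the other appears to be a length-10 Vector shaped like x. The base-Julia sketch below mimics only that step. accum_sketch is a hypothetical stand-in that assumes array gradients are combined by elementwise broadcasting (as the materialize/broadcast frames suggest), and Base's fill stands in for FillArrays.Fill.

```julia
# Hypothetical stand-in for Zygote's accum on two array gradients:
# combine them elementwise with broadcasting (cf. frame [7]).
accum_sketch(x::AbstractArray, y::AbstractArray) = x .+ y

grad_from_sum = fill(1.0, 5)  # like the Fill gradient of sum over a length-5 output
grad_for_x    = rand(10)      # a gradient shaped like the length-10 input x

# Accumulating two same-shaped gradients works fine:
@assert accum_sketch(grad_from_sum, fill(2.0, 5)) == fill(3.0, 5)

# Mixing the two shapes raises the same kind of error as in the issue:
err = try
    accum_sketch(grad_from_sum, grad_for_x)
    nothing
catch e
    e
end
println(typeof(err))  # prints DimensionMismatch
```

In other words, the forward shapes are fine; the problem is that a gradient belonging to the output side seems to be routed into a slot that also holds x's gradient.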

The error goes away with any one of the following:

This is with Zygote version 0.6.41.

mcabbott commented 2 years ago

Even writing a = identity(x) instead of a = x is enough to stop this. Zygote seems to sometimes get confused by the fact that an assignment does not permanently identify two variables.
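A sketch of that workaround applied to the reproducer above, assuming Flux 0.13 or later (for the Dense(10 => 5) syntax and the weight/bias fields), Functors installed as in the original, and the same implicit-parameter API (Flux.params) used in the issue. TestIdentity is a hypothetical name chosen so it does not clash with the original Test struct; only the first assignment differs from the reproducer. This is a workaround for the symptom, not a fix for the underlying Zygote bug.

```julia
import Flux
import Zygote
using Functors

# Hypothetical copy of the issue's Test struct, renamed to avoid a clash.
struct TestIdentity
    a
    b
end

@functor TestIdentity

function (m::TestIdentity)(x)
    # Workaround from the thread: identity returns x unchanged, but
    # per the comment above, routing x through this call avoids the error.
    a = identity(x)
    for f in m.a
        a = f(a)
    end
    b = x
    for f in m.b
        b = f(b)
    end
    a + b
end

t = TestIdentity([Flux.Dense(10 => 5)], [Flux.Dense(10 => 5)])
x = rand(Float32, 10)  # Float32 to match Dense's default parameter type
gs = Zygote.gradient(() -> sum(t(x)), Flux.params(t))
```

Gradients can then be looked up per parameter, e.g. gs[t.a[1].weight], which should be a 5×10 matrix.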

Perhaps similar to #1236 and https://github.com/FluxML/Zygote.jl/issues/1198.

radudiaconu0 commented 4 months ago

I have the same error and I don't know how to solve it.