FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Gradient of scalar function of gradient giving mutating array error #1497

Open mrazomej opened 7 months ago

mrazomej commented 7 months ago

I am trying to implement a particular flavor of variational autoencoder (the Riemannian Hamiltonian VAE from this publication) whose loss function involves two steps that I reproduce in my MWE:

  1. Reshaping the output of a neural network into a lower triangular matrix and then computing operations on that matrix, such as the log-determinant in the MWE.
  2. Computing a gradient with respect to the neural network inputs and reducing the resulting gradient to a scalar (a simple sum in my MWE).
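For the log-determinant in step 1 specifically, there is a closed form that avoids forming L * transpose(L): since L is lower triangular, det(L Lᵀ) = det(L)² = (∏ L_ii)², so logdet(L Lᵀ) = 2 Σ log|L_ii|. The sketch below checks this numerically; `logdet_LLt` is a hypothetical helper name, not part of LinearAlgebra. In the MWE the diagonal entries are x[1], x[3] and x[6], so this particular loss would not need the matrix at all, although other parts of the full RHVAE loss may still need L itself.

```julia
using LinearAlgebra

# Hypothetical helper (not part of LinearAlgebra): for a lower triangular L,
# det(L * L') = det(L)^2 = prod(diag(L))^2, so
# logdet(L * L') == 2 * sum(log(abs(L[i, i])) for each diagonal entry).
logdet_LLt(L::AbstractMatrix) = 2 * sum(log ∘ abs, diag(L))

# Numerical check against LinearAlgebra.logdet on a small example.
L = [1.5 0.0 0.0; 0.3 -2.0 0.0; 0.7 0.1 0.5]
logdet_LLt(L) ≈ logdet(L * transpose(L))  # true
```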

I have read as many of the related issues as I could, and I know that taking gradients of functions of gradients with Zygote is a pain. But I would like to know whether there is a way around this, or whether this simply cannot be done with Zygote.
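Step 2 is what makes this a second-order problem. Writing the MWE's inner function `_logdet` as f(θ, x), with θ the model parameters and L_θ(x) the lower triangular matrix built from the model output (notation introduced here for illustration), the loss and its parameter gradient are:

```latex
\begin{aligned}
f(\theta, x) &= \log\det\!\big(L_\theta(x)\, L_\theta(x)^{\top}\big), \\
\ell(\theta) &= \sum_{k} \frac{\partial f}{\partial x_k}(\theta, x) = \mathbf{1}^{\top} \nabla_x f(\theta, x), \\
\nabla_\theta \ell(\theta) &= \sum_{k} \frac{\partial^2 f}{\partial \theta\, \partial x_k}(\theta, x)
  = \big(\nabla_\theta \nabla_x^{\top} f(\theta, x)\big)\, \mathbf{1}.
\end{aligned}
```

So the outer `Zygote.gradient` has to differentiate through the pullbacks that the inner `Zygote.gradient` builds, which, as far as I understand, means every rule used inside f (including the one for `Buffer`'s `setindex!`) must itself be differentiable.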

Here is my MWE:

import Flux
import Zygote
import LinearAlgebra

# Define dummy model
model = Flux.Chain(
    Flux.Dense(4, 3, Flux.σ),
    Flux.Dense(3, 6, Flux.identity)
)

# Define function to turn output into lower triangular matrix
function lower_triangular(x::Vector{Float32})
    # Initialize lower triangular matrix as Zygote.Buffer
    L_buff = Zygote.bufferfrom(
        zeros(Float32, 3, 3)
    )
    # Populate lower triangular matrix
    L_buff[1, 1] = x[1]
    L_buff[2, 1] = x[2]
    L_buff[2, 2] = x[3]
    L_buff[3, 1] = x[4]
    L_buff[3, 2] = x[5]
    L_buff[3, 3] = x[6]

    return copy(L_buff)
end # lower_triangular

function loss(model::Flux.Chain, input::Vector{Float32})
    # Define internal function 
    function _logdet(input::Vector{Float32})
        # Forward pass input through model
        x = model(input)
        # Convert output to lower triangular matrix
        L = lower_triangular(x)
        # Compute log determinant of lower triangular matrix
        return LinearAlgebra.logdet(L * transpose(L))
    end # _logdet

    # Compute gradient of loss function with respect to input and sum values.
    return sum(first(Zygote.gradient(input -> _logdet(input), input)))
end # loss

# Define fixed input
x = rand(Float32, 4)
# Evaluate the loss function (this works)
loss(model, x)

# Compute gradient of loss function with respect to model parameters
Zygote.gradient(model -> loss(model, x), model)

Running the last line gives the following error, which refers to Zygote's inability to differentiate through array mutation:

ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float32}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:70
  [3] (::Zygote.var"#539#540"{Matrix{Float32}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:82
  [4] (::Zygote.var"#2623#back#541"{Zygote.var"#539#540"{Matrix{Float32}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [5] (::Zygote.var"#291#292"{Tuple{Tuple{…}, Tuple{…}}, Zygote.var"#2623#back#541"{Zygote.var"#539#540"{…}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206
  [6] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.var"#2623#back#541"{…}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [7] (::Zygote.var"#1145#1147"{Zygote.Context{false}, Zygote.Buffer{Float32, Matrix{…}}, Tuple{Int64, Int64}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/buffer.jl:23 [inlined]
  [8] (::Zygote.Pullback{Tuple{Zygote.var"#1145#1147"{…}, Nothing}, Any})(Δ::Tuple{Nothing, Float32, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
  [9] #3702#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Float32, Nothing, Nothing})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [11] lower_triangular
    @ ~/git/AutoEncode/experiments/gpu_rhvae_jointlogencoder_jointlogdecoder_3Dpeak/mwe.jl:18 [inlined]
 [12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Vector{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [13] _logdet
    @ ~/git/AutoEncode/experiments/gpu_rhvae_jointlogencoder_jointlogdecoder_3Dpeak/mwe.jl:34 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [15] #46
    @ ~/git/AutoEncode/experiments/gpu_rhvae_jointlogencoder_jointlogdecoder_3Dpeak/mwe.jl:40 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, FillArrays.Fill{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [17] #75
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91 [inlined]
 [18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{FillArrays.Fill{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [19] gradient
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148 [inlined]
 [20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{FillArrays.Fill{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [21] loss
    @ ~/git/AutoEncode/experiments/gpu_rhvae_jointlogencoder_jointlogdecoder_3Dpeak/mwe.jl:40 [inlined]
 [22] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [23] #49
    @ ~/git/AutoEncode/experiments/gpu_rhvae_jointlogencoder_jointlogdecoder_3Dpeak/mwe.jl:49 [inlined]
 [24] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [25] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [26] gradient(f::Function, args::Flux.Chain{Tuple{Flux.Dense{…}, Flux.Dense{…}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
 [27] top-level scope
    @ ~/git/AutoEncode/experiments/gpu_rhvae_jointlogencoder_jointlogdecoder_3Dpeak/mwe.jl:49
Some type information was truncated. Use `show(err)` to see complete types.

Thank you in advance for your help.

mcabbott commented 7 months ago

I can reproduce this. The following nested gradient gives the same error with less code and no Flux:

julia> Zygote.gradient(x -> sum(abs2, lower_triangular(x)), collect(Float32, 1:6))
(Float32[2.0, 4.0, 6.0, 8.0, 10.0, 12.0],)

julia> Zygote.gradient(collect(Float32, 1:6)) do x1
         Zygote.gradient(x -> sum(abs2, lower_triangular(x)), x1)[1] |> sum
       end
ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float32}, ...)
# I believe the same stack trace as above

Avoiding Buffer by writing the matrix as a literal gives a different error (perhaps because hvcat_pullback is not second-order friendly?):

julia> function lower_triangular(x::Vector{Float32})
         [x[1] 0f0 0f0
          x[2] x[3] 0f0
          x[4] x[5] x[6]]
       end
lower_triangular (generic function with 1 method)

This uses simpler rules:

julia> Zygote.refresh()

julia> function lower_triangular(x::Vector{Float32})
         reshape([x[1], x[2], x[4], 0f0, x[3], x[5], 0f0, 0f0, x[6]], 3,3)
       end
lower_triangular (generic function with 1 method)

julia> Zygote.gradient(x -> sum(abs2, lower_triangular(x)), collect(Float32, 1:6))
(Float32[2.0, 4.0, 6.0, 8.0, 10.0, 12.0],)

julia> Zygote.gradient(collect(Float32, 1:6)) do x1
         Zygote.gradient(x -> sum(abs2, lower_triangular(x)), x1)[1] |> sum
       end
(Float32[2.0, 2.0, 2.0, 2.0, 2.0, 2.0],)

julia> Zygote.gradient(model -> loss(model, x), model)  # with above Flux model
((layers = ((weight = Float32[3.3456602 3.5011783 4.486716 4.070404; -24.406696 -25.78394 -34.5117 -30.824911; 38.69364 40.78779 54.058647 48.45276], bias = Float32[3.7910285, -33.572727, 51.04849], σ = nothing), (weight = Float32[-1.4263158 -1.3037528 0.7782692; 0.0008880339 0.00072802976 0.00042506802; … ; -0.0020078237 -0.0016477455 -0.0009427298; -237.13728 -195.81625 -98.22924], bias = Float32[-0.40942383, 0.0011825562, -1.3378944, -0.0043182373, -0.0026550293, -300.20334], σ = nothing)),),)

mcabbott commented 7 months ago

FWIW, here's an attempt at a minimal example for the hvcat error. I'm not sure whether the bug is in ChainRules or in Zygote (e.g. in ∇map):

julia> Zygote.gradient(2.0) do x
         Zygote.gradient(y -> y^3, x)[1]
       end
(12.0,)

julia> Zygote.gradient(2.0) do x
         Zygote.gradient(y -> [y, y][1]^3, x)[1]
       end
(12.0,)

julia> Zygote.gradient(2.0) do x
         Zygote.gradient(y -> [y y y][1]^3, x)[1]
       end
ERROR: Compiling Tuple{ChainRules.var"#1379#1384"{ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, Tuple{Int64, Int64}, Matrix{Float64}}}: ArgumentError: array must be non-empty
Stacktrace:
  [1] macro expansion
    @ ./compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context{false}, ::ChainRules.var"#1379#1384"{ChainRulesCore.ProjectTo{…}, Tuple{…}, Matrix{…}})
    @ Zygote ./compiler/interface2.jl:81
  [3] unthunk
    @ ~/.julia/packages/ChainRulesCore/UrpQe/src/tangent_types/thunks.jl:204 [inlined]
  [4] unthunk
    @ ~/.julia/packages/ChainRulesCore/UrpQe/src/tangent_types/thunks.jl:237 [inlined]
  [5] _pullback(ctx::Zygote.Context{…}, f::typeof(ChainRulesCore.unthunk), args::ChainRulesCore.InplaceableThunk{…})
    @ Zygote ./compiler/interface2.jl:0
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:110 [inlined]
  [7] (::Zygote.var"#662#666"{…})(args::ChainRulesCore.InplaceableThunk{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:187
  [8] map
    @ ./tuple.jl:282 [inlined]
  [9] map
    @ ./tuple.jl:283 [inlined]
 [10] ∇map(cx::Zygote.Context{…}, f::typeof(Zygote.wrap_chainrules_output), args::Tuple{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:187
 [11] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:213 [inlined]
 [12] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [13] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:111 [inlined]
 [14] _pullback(ctx::Zygote.Context{…}, f::typeof(Zygote.wrap_chainrules_output), args::Tuple{…})
    @ Zygote ./compiler/interface2.jl:0
 [15] ZBack
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
 [16] #89
    @ ./REPL[45]:2 [inlined]
 [17] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float64)
    @ Zygote ./compiler/interface2.jl:0
 [18] #75
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91 [inlined]
 [19] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Float64)
    @ Zygote ./compiler/interface2.jl:0
 [20] gradient
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148 [inlined]
 [21] _pullback(::Zygote.Context{false}, ::typeof(Zygote.gradient), ::var"#89#91", ::Float64)
    @ Zygote ./compiler/interface2.jl:0
 [22] #88
    @ ./REPL[45]:2 [inlined]
 [23] _pullback(ctx::Zygote.Context{false}, f::var"#88#90", args::Float64)
    @ Zygote ./compiler/interface2.jl:0
 [24] pullback(f::Function, cx::Zygote.Context{false}, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:90
 [25] pullback
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:88 [inlined]
 [26] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:147
 [27] top-level scope
    @ REPL[45]:1
Some type information was truncated. Use `show(err)` to see complete types.

julia> ForwardDiff.derivative(2.0) do x
         Zygote.gradient(y -> [y y y][1]^3, x)[1]
       end
12.0

(@v1.11) pkg> st Zygote ChainRules
Status `~/.julia/environments/v1.11/Project.toml`
⌃ [082447d4] ChainRules v1.58.1
  [e88e6eb3] Zygote v0.6.69
Info Packages marked with ⌃ have new versions available and may be upgradable.

mcabbott commented 7 months ago

And here are two attempts at a minimal example of using Buffer at second order. I think the first hits an unrelated bug (a method ambiguity), but the second appears to say that Buffer does not support second-order differentiation at all:

julia> function buf_id(x::Real)
        b = Zygote.Buffer(zeros(1))
        b[1] = x
        sum(copy(b))
       end;

julia> buf_id(pi)
3.141592653589793

julia> Zygote.gradient(x -> buf_id(x)^3, 2.0)
(12.0,)

julia> Zygote.gradient(x -> Zygote.gradient(y -> buf_id(y)^3, x)[1], 2.0)
ERROR: MethodError: _pullback(::Zygote.Context{false}, ::typeof(Base.Broadcast.broadcasted), ::typeof(identity), ::FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}) is ambiguous.

Candidates:
  _pullback(__context__::ZygoteRules.AContext, var"586"::typeof(Base.Broadcast.broadcasted), var"587"::typeof(identity), x::Union{AbstractArray{<:T}, T} where T<:Number)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:66
  _pullback(__context__::ZygoteRules.AContext, var"558"::typeof(Base.Broadcast.broadcasted), op, r::FillArrays.AbstractFill{<:Real})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:66

Possible fix, define
  _pullback(::ZygoteRules.AContext, ::typeof(Base.Broadcast.broadcasted), ::typeof(identity), ::FillArrays.AbstractFill{…})

Stacktrace:
  [1] copy_sensitivity
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/buffer.jl:54 [inlined]
  [2] _pullback(ctx::Zygote.Context{…}, f::Zygote.var"#copy_sensitivity#1161"{…}, args::FillArrays.Fill{…})
    @ Zygote ./compiler/interface2.jl:0
  [3] #3732#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
  [4] _pullback(ctx::Zygote.Context{…}, f::Zygote.var"#3732#back#1162"{…}, args::FillArrays.Fill{…})
    @ Zygote ./compiler/interface2.jl:0
  [5] buf_id
    @ ./REPL[63]:4 [inlined]
  [6] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float64)
    @ Zygote ./compiler/interface2.jl:0
  [7] #117
    @ ./REPL[66]:1 [inlined]
  [8] _pullback(ctx::Zygote.Context{false}, f::Zygote.Pullback{Tuple{…}, Tuple{…}}, args::Float64)
    @ Zygote ./compiler/interface2.jl:0
  [9] #75
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}}, args::Float64)
    @ Zygote ./compiler/interface2.jl:0
 [11] gradient
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148 [inlined]
 [12] _pullback(::Zygote.Context{false}, ::typeof(Zygote.gradient), ::var"#117#119", ::Float64)
    @ Zygote ./compiler/interface2.jl:0
 [13] #116
    @ ./REPL[66]:1 [inlined]
 [14] _pullback(ctx::Zygote.Context{false}, f::var"#116#118", args::Float64)
    @ Zygote ./compiler/interface2.jl:0
 [15] pullback(f::Function, cx::Zygote.Context{false}, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:90
 [16] pullback
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:88 [inlined]
 [17] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:147
 [18] top-level scope
    @ REPL[66]:1
Some type information was truncated. Use `show(err)` to see complete types.

julia> function buf_id2(x::Real)
        b = Zygote.Buffer(zeros(1))
        b[1] = x
        only(copy(b))
       end;

julia> Zygote.gradient(x -> buf_id2(x)^3, 2.0)
(12.0,)

julia> Zygote.gradient(x -> Zygote.gradient(y -> buf_id2(y)^3, x)[1], 2.0)
ERROR: Mutating arrays is not supported -- called setindex!(Vector{Float64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _throw_mutation_error(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:70
  [3] (::Zygote.var"#539#540"{Vector{Float64}})(::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/array.jl:82
  [4] (::Zygote.var"#2623#back#541"{Zygote.var"#539#540"{Vector{Float64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [5] (::Zygote.var"#291#292"{Tuple{Tuple{…}, Tuple{…}}, Zygote.var"#2623#back#541"{Zygote.var"#539#540"{…}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:206
  [6] (::Zygote.var"#2169#back#293"{Zygote.var"#291#292"{Tuple{…}, Zygote.var"#2623#back#541"{…}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
  [7] #1145
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/buffer.jl:23 [inlined]
  [8] (::Zygote.Pullback{Tuple{Zygote.var"#1145#1147"{…}, Nothing}, Any})(Δ::Tuple{Nothing, Float64, Nothing})
    @ Zygote ./compiler/interface2.jl:0
  [9] #3702#back
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72 [inlined]
 [10] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Nothing, Float64, Nothing})
    @ Zygote ./compiler/interface2.jl:0
 [11] buf_id2
    @ ./REPL[67]:3 [inlined]
 [12] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Float64})
    @ Zygote ./compiler/interface2.jl:0
 [13] #123
    @ ./REPL[69]:1 [inlined]
 [14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Nothing, Float64})
    @ Zygote ./compiler/interface2.jl:0
 [15] #75
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91 [inlined]
 [16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float64})
    @ Zygote ./compiler/interface2.jl:0
 [17] gradient
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148 [inlined]
 [18] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Tuple{Float64})
    @ Zygote ./compiler/interface2.jl:0
 [19] #122
    @ ./REPL[69]:1 [inlined]
 [20] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ./compiler/interface2.jl:0
 [21] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
 [22] gradient(f::Function, args::Float64)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
 [23] top-level scope
    @ REPL[69]:1
Some type information was truncated. Use `show(err)` to see complete types.

mrazomej commented 7 months ago

Thank you very much for your quick response, @mcabbott. Indeed, removing Zygote.Buffer got rid of the mutating-array error. However, in my full implementation I now get the same error you hit in your minimal hvcat example.

This is way out of my comfort zone; I don't even know where to start looking for a fix.