JuliaDiff / ChainRules.jl

forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

No rules for `typed_hvcat` #743

Open MasonProtter opened 1 year ago

MasonProtter commented 1 year ago

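The missing rule causes problems with Zygote. The untyped literal `[x x]` differentiates fine, but the typed literal `Float32[x x]` falls through to Base's mutating implementation: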

```julia
julia> using Zygote

julia> gradient(x -> sum([x x]), pi/2)
(2.0,)

julia> gradient(x -> sum(Float32[x x]), pi/2)
ERROR: Mutating arrays is not supported -- called setindex!(Matrix{Float32}, ...)
...
  [7] typed_hcat
    @ Zygote ./abstractarray.jl:1610 [inlined]
```
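The trace ends in `typed_hcat` because a typed array literal lowers to a different Base function than an untyped one. `Meta.lower` shows this without running anything; adding a `;` row separator gives `typed_hvcat`, the function named in the title. The variable names below are illustrative only.

```julia
# Lowering only rewrites syntax, so `x` does not need to be defined here.
lowered_untyped = string(Meta.lower(Main, :([x x])))             # expect Base.hcat
lowered_typed   = string(Meta.lower(Main, :(Float32[x x])))      # expect Base.typed_hcat
lowered_typed2d = string(Meta.lower(Main, :(Float32[x x; x x]))) # expect Base.typed_hvcat

for s in (lowered_untyped, lowered_typed, lowered_typed2d)
    println(s)
end
```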

Ideally we'd teach ChainRules how to handle typed concatenations like this directly, so that Zygote uses a rule instead of differentiating through the mutating Base implementation.
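A minimal sketch of what such a rule might compute, written in plain Julia so it runs without ChainRulesCore. `typed_hcat_rrule` is a hypothetical name: it returns the `(y, pullback)` pair an `rrule` would, uses `nothing` where ChainRulesCore would use `NoTangent()`, and omits the leading cotangent for the function itself that a real `rrule` also returns. It assumes hcat-style inputs (scalars, vectors, matrices), converts each cotangent to a float version of the input's type as a rough stand-in for `ProjectTo`, and ignores structured types such as `Adjoint`.

```julia
# Hypothetical, package-free sketch of a reverse rule for `Base.typed_hcat`,
# the function that `T[x y ...]` lowers to. Returns `(y, pullback)` like an rrule.
function typed_hcat_rrule(::Type{T}, xs...) where {T}
    # Forward pass: builds a fresh Matrix{T} via Base; the AD never sees the mutation.
    y = Base.typed_hcat(T, xs...)
    function typed_hcat_pullback(dy::AbstractMatrix)
        dxs = Any[]
        col = 0
        for x in xs
            # Scalars and vectors fill one column; a matrix fills size(x, 2) columns.
            ncols = x isa AbstractArray ? size(x, 2) : 1
            block = dy[:, (col + 1):(col + ncols)]
            col += ncols
            if x isa Number
                # Stand-in for ProjectTo: return a float of the input's own type.
                push!(dxs, float(typeof(x))(only(block)))
            else
                # Undo the column layout (a vector comes back as a vector).
                push!(dxs, float(eltype(x)).(reshape(block, size(x))))
            end
        end
        # One cotangent per argument of typed_hcat: `nothing` for the type T.
        return (nothing, dxs...)
    end
    return y, typed_hcat_pullback
end

# The issue's example: d/dx sum(Float32[x x]) at x = pi/2.
y, back = typed_hcat_rrule(Float32, pi/2, pi/2)
_, dx1, dx2 = back(ones(Float32, size(y)))  # cotangent of `sum` is all ones
println(dx1 + dx2)  # x appears twice, so the pieces accumulate; prints 2.0
```

The pullback only slices the incoming cotangent, so nothing is mutated, and the accumulated result matches the `2.0` that Zygote already returns for the untyped `[x x]`. In ChainRules proper this would presumably become a method along the lines of `rrule(::typeof(Base.typed_hcat), ::Type{T}, xs...)`, with a matching rule for `typed_hvcat`.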

mcabbott commented 1 year ago

See also #695