Open jondeuce opened 2 months ago
Somehow the gradient of the trace needs to be a `Diagonal{..., CuArray{...`, made either by calling `similar` or by having a special rule for `CuArray`.
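As a rough sketch of the `similar` option, the cotangent's diagonal could be allocated from the input so that it stays in the same array family. The helper name `tr_cotangent_similar` is hypothetical and is not the rule in ChainRules or Zygote. It is only exercised on a CPU `Array` here, on the assumption that `similar` on a `CuArray` returns a `CuArray` in the same way.

```julia
using LinearAlgebra

# Hypothetical helper for illustration only; not the actual ChainRules/Zygote rule.
# Builds the cotangent of tr(x) as a Diagonal whose storage comes from similar(x, ...),
# so the diagonal has the same array family as the input.
function tr_cotangent_similar(x::AbstractMatrix, dy::Number)
    d = similar(x, typeof(dy), size(x, 1))  # vector allocated "like" x
    fill!(d, dy)                            # every diagonal entry is the incoming cotangent
    return Diagonal(d)
end

x = [1.0 2.0; 3.0 4.0]
dx = tr_cotangent_similar(x, 1.0)
```

On a plain `Matrix{Float64}` this gives a `Diagonal` backed by a `Vector{Float64}`, which is mutable and broadcasts like any other dense-backed `Diagonal`.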
The rule in CR always makes a `Diagonal{..., Array{...`, which won't be any better than the present state:
The rule here, which makes a `Diagonal{..., Fill{...`, seems like probably a bad idea, a premature optimisation:
There's a similar issue for `sum`, where https://github.com/FluxML/Zygote.jl/pull/1453/files wants to remove the rule which uses FillArrays to give this:
julia> Zygote.gradient(sum, [1 2; 3 4.])
(Fill(1.0, 2, 2),)
and also remove the special rule for `sum(xs::AbstractGPUArray)`, which uses `similar`. But in that case the CR rules always use `similar`.
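To make the two approaches for `sum` concrete, here is a minimal sketch that contrasts a lazy `Fill` gradient with one allocated via `similar`. Both helper names are hypothetical and neither is the actual Zygote or ChainRules rule. It assumes the FillArrays package is installed.

```julia
using FillArrays  # third-party package, assumed to be installed

# Hypothetical helpers for illustration only; not the actual Zygote/ChainRules rules.
# Lazy version: no allocation, but the result is an immutable Fill on the CPU side.
sum_grad_fill(x::AbstractArray, dy::Number) = Fill(dy, size(x))

# similar-based version: allocates an array of the same family and shape as x.
sum_grad_similar(x::AbstractArray, dy::Number) = fill!(similar(x, typeof(dy)), dy)

x = [1 2; 3 4.0]
g_lazy = sum_grad_fill(x, 1.0)
g_dense = sum_grad_similar(x, 1.0)
```

The two results compare equal elementwise. The difference is only in the container type, which is what matters when the gradient is later accumulated or broadcast together with GPU arrays.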
MWE:
Seems to be hitting this generic `accum` method and falling back to scalar indexing.
There's a note here about efficiently implementing the `rrule` for `LinearAlgebra.tr`, which returns a `Fill` wrapped in a `Diagonal`, and this seems to cause issues with broadcasting. In fact, here's an even smaller MWE:
Package and version info: