Closed by pbrehmer 6 months ago
Simply from looking at the stack trace (I have not tested it myself), it seems as if the pullback function of `tensortrace!` gets a `Tuple{Float64}` as adjoint input `ΔC`, instead of an `Array{Float64,0}`. This is strange, as the pullback rule of `tensorscalar` does transform a scalar back into an `Array{Float64,0}` value. Is Zygote trying to be smart here, by replacing this zero-dimensional array with a `Tuple` in between these two steps? Or is it somehow simply sidestepping the pullback rule of `tensorscalar`?
I just quickly checked: the pullback rule of `tensorscalar` is being called and it does indeed output an `Array{Float64,0}`. Then `ΔC` somehow gets converted to a `Tuple{Float64}` in `tensortrace!`. Is this perhaps also related to the `@thunk` and `unthunk` of ChainRulesCore?
What's also weird is that `pC` is just an empty tuple `((), ())`, such that `numind(pC)` in `tensortrace!` errors as well when I force `ΔC` to be of the correct type. So I don't quite understand what is happening under the hood of Zygote here.
So apparently it's not Zygote's fault, but ours.
julia> scale(fill(0.5), 0.3)
(0.15,)
`scale` is the culprit here: it changes the `Array{Float64,0}` into a `Tuple{Float64}`. I will investigate further; this should very clearly not happen!
Addendum: The reason seems to lie with broadcasting. Broadcasting is used to implement `scale(::AbstractArray, ::Number)`, but it treats zero-dimensional arrays specially and does not preserve them.
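To illustrate the broadcasting behaviour described above, here is a minimal sketch in plain Julia (no packages needed). It is not the actual implementation of `scale` in VectorInterface and it does not reproduce the `Tuple` result shown earlier; it only shows that out-of-place broadcasting drops the zero-dimensional container, and one way (assumed here for illustration, not necessarily what the linked commit does) to keep it.

```julia
# A zero-dimensional array holding a single Float64.
x = fill(0.5)

# Out-of-place broadcasting unwraps the result when all arguments
# are zero-dimensional or scalars, so the container is lost.
y = x .* 0.3
println(typeof(y))   # Float64

# Broadcasting in place into a preallocated zero-dimensional array
# keeps the container type.
z = similar(x)
z .= x .* 0.3
println(typeof(z))   # Array{Float64, 0}
```

The in-place form keeps an `Array{Float64,0}` because the destination fixes the output container, whereas the out-of-place form lets broadcasting choose the result type.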
Ok, so this fixes it on my side: https://github.com/Jutho/VectorInterface.jl/commit/d4d11298ecc6feae0a33588563709830798ac87f
As soon as the tests turn green, I will tag a new release.
Great, thanks for the fast fix!
Fixed and tagged in v4.1
As I was playing around with the new AD capabilities, I found that taking gradients (using Zygote) of operations involving array traces leads to errors in the `tensortrace!` reverse-rule. For example, when taking the gradient of a matrix trace, I get the following error message:
Interestingly, doing the same using `TensorMap`s from TensorKit will not error:
All of this was run on the latest versions of TensorOperations, TensorKit and Zygote:
More complicated contractions resulting in a scalar also seem to have the same problem when contracting `Array`s.