Jutho / TensorOperations.jl

Julia package for tensor contractions and related operations
https://jutho.github.io/TensorOperations.jl/stable/

Taking gradients of traces #154

Closed: pbrehmer closed this issue 6 months ago

pbrehmer commented 8 months ago

While playing around with the new AD capabilities, I found that taking gradients (using Zygote) of operations involving array traces leads to errors in the tensortrace! reverse rule. For example, taking the gradient of a matrix trace gives the following error:

julia> using TensorOperations, Zygote

julia> function trace(m)
           @tensor s = m[1, 1]
       end
trace (generic function with 1 method)

julia> m = randn(2, 2)
2×2 Matrix{Float64}:
  0.080359  -1.36932
 -0.384686  -0.952139

julia> gradient(trace, m)
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{}}}})(::Tuple{Float64})

Closest candidates are:
  (::ChainRulesCore.ProjectTo{AbstractArray})(::Union{LinearAlgebra.Adjoint{T, var"#s972"}, LinearAlgebra.Transpose{T, var"#s972"}} where {T, var"#s972"<:(AbstractVector)})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/TGTmr/src/projection.jl:247
  (::ChainRulesCore.ProjectTo{AbstractArray})(::AbstractArray{<:ChainRulesCore.AbstractZero})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/TGTmr/src/projection.jl:244
  (::ChainRulesCore.ProjectTo{AbstractArray})(::AbstractArray{S, M}) where {S, M}
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/TGTmr/src/projection.jl:219
  ...

Stacktrace:
  [1] (::TensorOperationsChainRulesCoreExt.var"#70#77"{Array{Float64, 0}, VectorInterface.Zero, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{}}}}})()
    @ TensorOperationsChainRulesCoreExt ~/.julia/packages/TensorOperations/7VyQe/ext/TensorOperationsChainRulesCoreExt.jl:156
  [2] unthunk
    @ ~/.julia/packages/ChainRulesCore/TGTmr/src/tangent_types/thunks.jl:204 [inlined]
  [3] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:110 [inlined]
  [4] map (repeats 2 times)
    @ ./tuple.jl:276 [inlined]
  [5] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:111 [inlined]
  [6] ZBack
    @ ~/.julia/packages/Zygote/XJ8pP/src/compiler/chainrules.jl:211 [inlined]
  [7] Pullback
    @ ./REPL[3]:2 [inlined]
  [8] (::Zygote.Pullback{Tuple{typeof(trace), Matrix{Float64}}, Tuple{Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensorscalar_pullback#45"{Array{Float64, 0}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensorfree!_pullback#44"}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#76"{Array{Float64, 0}, Tuple{Tuple{}, Tuple{}}, Matrix{Float64}, Tuple{Tuple{Int64}, Tuple{Int64}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_add_pullback#23"{Tuple{DataType, Tuple{Tuple{}, Tuple{}}, Matrix{Float64}, Symbol, Bool}}}, Zygote.Pullback{Tuple{typeof(scalartype), Matrix{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface2.jl:0
  [9] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(trace), Matrix{Float64}}, Tuple{Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensorscalar_pullback#45"{Array{Float64, 0}}}, Zygote.var"#2017#back#204"{typeof(identity)}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensorfree!_pullback#44"}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#pullback#76"{Array{Float64, 0}, Tuple{Tuple{}, Tuple{}}, Matrix{Float64}, Tuple{Tuple{Int64}, Tuple{Int64}}, Symbol, VectorInterface.One, VectorInterface.Zero, Tuple{}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Number, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}}}}, Zygote.ZBack{TensorOperationsChainRulesCoreExt.var"#tensoralloc_add_pullback#23"{Tuple{DataType, Tuple{Tuple{}, Tuple{}}, Matrix{Float64}, Symbol, Bool}}}, Zygote.Pullback{Tuple{typeof(scalartype), Matrix{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Matrix{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(scalartype), Type{Float64}}, Tuple{}}}}, Zygote.ZBack{ChainRules.var"#typeof_pullback#45"}}}}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:45
 [10] gradient(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/XJ8pP/src/compiler/interface.jl:97
 [11] top-level scope
    @ REPL[5]:1

Interestingly, the same computation using TensorMaps from TensorKit does not error:

julia> using TensorKit

julia> mtensor = Tensor(m, ℝ^2 * ℝ^2)
TensorMap((ℝ^2 ⊗ ℝ^2) ← ProductSpace{CartesianSpace, 0}()):
  0.0803590283106982   -1.3693169393992737
 -0.38468550557063524  -0.9521387115295246

julia> gradient(trace, mtensor)
(TensorMap((ℝ^2 ⊗ ℝ^2) ← ProductSpace{CartesianSpace, 0}()):
 1.0  0.0
 0.0  1.0
,)
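The TensorKit result is the expected one: since tr(m) = Σᵢ m[i, i], the derivative ∂tr(m)/∂m[i, j] is 1 when i == j and 0 otherwise, so the gradient is the identity matrix. As a sanity check that needs neither Zygote nor TensorOperations, the sketch below compares a central finite-difference gradient of a plain-Julia trace against the identity. The names `mytrace` and `fd_gradient` are hypothetical helpers introduced only for this illustration.

```julia
using LinearAlgebra

# Hypothetical helper: trace as an explicit sum over the diagonal,
# mirroring `@tensor s = m[1, 1]` (assumes a square matrix)
mytrace(m) = sum(m[i, i] for i in axes(m, 1))

# Hypothetical helper: central finite-difference gradient of a scalar function f at m
function fd_gradient(f, m; h = 1e-6)
    g = similar(m)
    for k in eachindex(m)
        mp = copy(m); mp[k] += h   # perturb entry k upwards
        mm = copy(m); mm[k] -= h   # perturb entry k downwards
        g[k] = (f(mp) - f(mm)) / (2h)
    end
    return g
end

m = randn(2, 2)
g = fd_gradient(mytrace, m)
g ≈ Matrix{Float64}(I, 2, 2)   # the gradient of the trace is the identity
```

Because the trace is linear, the finite-difference result matches the identity up to rounding, which is what a working reverse rule for tensortrace! should also return for a plain `Matrix`.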

All of this was run on the latest versions of TensorOperations, TensorKit and Zygote:

(TOTest) pkg> status
Status `~/repos/TOTest/Project.toml`
  [07d1fe3e] TensorKit v0.12.0
  [6aa20fa7] TensorOperations v4.0.7
  [e88e6eb3] Zygote v0.6.65

More complicated contractions that result in a scalar also seem to have the same problem when applied to Arrays.

Jutho commented 8 months ago

Just from looking at the stack trace (I have not tested it myself), it seems that the pullback function of tensortrace! receives a Tuple{Float64} as its adjoint input ΔC instead of an Array{Float64,0}. This is strange, because the pullback rule of tensorscalar does transform a scalar back into an Array{Float64,0} value. Is Zygote trying to be smart here by replacing this zero-dimensional array with a Tuple between these two steps? Or is it somehow sidestepping the pullback rule of tensorscalar altogether?
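To make the expected data flow concrete, here is a toy sketch of the round trip described above: a scalar is extracted from a zero-dimensional array, and the pullback wraps the incoming scalar cotangent back into a zero-dimensional array, which is the ΔC that the next pullback should receive. `toy_tensorscalar` and `toy_tensorscalar_pullback` are hypothetical stand-ins, not TensorOperations' actual definitions.

```julia
# Hypothetical stand-in: extract the single entry of a zero-dimensional array
toy_tensorscalar(C::AbstractArray{<:Any,0}) = C[]

# Hypothetical stand-in for its pullback: wrap the scalar cotangent back into
# a zero-dimensional array, the form the tensortrace! pullback expects as ΔC
toy_tensorscalar_pullback(Δ::Number) = fill(Δ)

C = fill(0.42)                       # zero-dimensional Array{Float64, 0}
s = toy_tensorscalar(C)              # plain Float64
ΔC = toy_tensorscalar_pullback(1.0)  # back to Array{Float64, 0}, not a Tuple
ΔC isa Array{Float64, 0}
```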

pbrehmer commented 8 months ago

I just quickly checked: the pullback rule of tensorscalar is being called, and it does indeed output an Array{Float64,0}. ΔC then somehow gets converted to a Tuple{Float64} in tensortrace!. Could this also be related to ChainRulesCore's @thunk and unthunk?

What's also weird is that pC is just an empty tuple ((), ()), so numind(pC) in tensortrace! errors as well when I force ΔC to have the correct type. I don't quite understand what is happening under the hood of Zygote here.

Jutho commented 8 months ago

So apparently it's not Zygote's fault, but ours.

julia> scale(fill(0.5), 0.3)
(0.15,)

scale is the culprit here: it changes the Array{Float64,0} into a Tuple{Float64}. I will investigate further; this should clearly not happen!

Addendum: the cause seems to be broadcasting. Broadcasting is used to implement scale(::AbstractArray, ::Number), but it treats zero-dimensional arrays specially and does not preserve them.
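The special treatment mentioned in the addendum is easy to see in plain Julia. An out-of-place broadcast whose result is zero-dimensional returns a bare scalar instead of an Array{Float64,0}, while an in-place broadcast into a preallocated destination keeps the container. This snippet only demonstrates the unwrapping. It does not by itself reproduce the Tuple{Float64} seen above, which presumably depends on how scale was implemented at the time.

```julia
x = fill(0.5)            # zero-dimensional Array{Float64, 0}

y = x .* 0.3             # out-of-place broadcast
y isa Float64            # true: the zero-dimensional result is unwrapped to a scalar

z = similar(x)           # preallocated zero-dimensional destination
z .= x .* 0.3            # in-place broadcast writes into z
z isa Array{Float64, 0}  # true: the container is preserved
```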

Jutho commented 8 months ago

Ok, so this fixes it on my side: https://github.com/Jutho/VectorInterface.jl/commit/d4d11298ecc6feae0a33588563709830798ac87f

As soon as the tests turn green, I will tag a new release.
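One way to avoid the unwrapping is to allocate the output with `similar` and broadcast into it in place. The sketch below illustrates that idea with a hypothetical `scale_preserving` helper; it is an assumption for illustration and is not the contents of the linked VectorInterface commit or the actual `scale` implementation.

```julia
# Hypothetical helper (not VectorInterface's `scale`): scale an array by a
# number without losing zero-dimensional containers
function scale_preserving(x::AbstractArray, α::Number)
    T = promote_type(eltype(x), typeof(α))
    y = similar(x, T)   # same shape as x, including the zero-dimensional case
    y .= x .* α         # in-place broadcast fills y instead of returning a scalar
    return y
end

scale_preserving(fill(0.5), 0.3)   # Array{Float64, 0} holding 0.15
scale_preserving([1.0, 2.0], 2)    # [2.0, 4.0]
```

Promoting the element type first means that scaling an integer array by a float still works, at the cost of always allocating a new array.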

pbrehmer commented 8 months ago

Great, thanks for the fast fix!

lkdvos commented 6 months ago

Fixed and tagged in v4.1