JuliaDiff / SparseDiffTools.jl

Fast Jacobian computation through sparsity exploitation and matrix coloring
MIT License

`autoback_hesvec` fails with `SparseMatrixCSC` #180

Closed (newalexander closed this issue 6 months ago)

newalexander commented 2 years ago

Hello! I can't seem to use `autoback_hesvec` when the function involves a `SparseMatrixCSC` matrix.

using Zygote, SparseArrays, SparseDiffTools

x, t = rand(Float32, 5), rand(Float32, 5)
A = sprand(Float32, 5, 5, 0.5)
loss(_x) = sum(tanh.(A * _x))

numback_hesvec(loss, x, t)  # works
autoback_hesvec(loss, x, t)  # fails

with the full message

ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float64, 1})
Closest candidates are:
  (::Type{T})(::Real, ::RoundingMode) where T<:AbstractFloat at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/rounding.jl:200
  (::Type{T})(::T) where T<:Number at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/boot.jl:770
  (::Type{T})(::AbstractChar) where T<:Union{AbstractChar, Number} at /Applications/Julia-1.7.app/Contents/Resources/julia/share/julia/base/char.jl:50
  ...
Stacktrace:
  [1] convert(#unused#::Type{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float64, 1})
    @ Base ./number.jl:7
  [2] setindex!(A::Vector{Float64}, x::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float64, 1}, i1::Int64)
    @ Base ./array.jl:903
  [3] (::ChainRulesCore.ProjectTo{SparseMatrixCSC, NamedTuple{(:element, :axes, :rowval, :nzranges, :colptr), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Vector{Int64}, Vector{UnitRange{Int64}}, Vector{Int64}}}})(dx::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float64, 1}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/uxrij/src/projection.jl:580
  [4] #1334
    @ ~/.julia/packages/ChainRules/3HAQW/src/rulesets/Base/arraymath.jl:36 [inlined]
  [5] unthunk
    @ ~/.julia/packages/ChainRulesCore/uxrij/src/tangent_types/thunks.jl:197 [inlined]
  [6] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:104 [inlined]
  [7] map
    @ ./tuple.jl:223 [inlined]
  [8] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:105 [inlined]
  [9] ZBack
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/chainrules.jl:204 [inlined]
 [10] Pullback
    @ ./REPL[5]:1 [inlined]
 [11] (::typeof(∂(loss)))(Δ::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float64, 1})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#57#58"{typeof(∂(loss))})(Δ::ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float64, 1})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:41
 [13] gradient(f::Function, args::Vector{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:76
 [14] (::SparseDiffTools.var"#78#79"{typeof(loss)})(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float32}, Float32, 1}})
    @ SparseDiffTools ~/.julia/packages/SparseDiffTools/9lSLn/src/differentiation/jaches_products_zygote.jl:39
 [15] autoback_hesvec(f::Function, x::Vector{Float32}, v::Vector{Float32})
    @ SparseDiffTools ~/.julia/packages/SparseDiffTools/9lSLn/src/differentiation/jaches_products_zygote.jl:41
 [16] top-level scope
    @ REPL[7]:1

This is with [47a9eef4] SparseDiffTools v1.20.0 and [082447d4] ChainRules v1.26.0. The issue was originally noticed here, where it was suggested to be reported here instead.

ChrisRackauckas commented 2 years ago

[3] (::ChainRulesCore.ProjectTo{SparseMatrixCSC, NamedTuple{(:element, :axes, :rowval, :nzranges, :colptr), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Vector{Int64}, Vector{UnitRange{Int64}}, Vector{Int64}}}})(dx::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{DataType, Float64}, Float64, 1}}) @ ChainRulesCore ~/.julia/packages/ChainRulesCore/uxrij/src/projection.jl:580

It seems like it's creating a SparseMatrixCSC of Float64. @DhairyaLGandhi how would I investigate why?
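
One way to poke at this in isolation is to exercise the projection step directly. This is a minimal sketch (not from the thread) using ChainRulesCore's `ProjectTo` API: building a projector for a sparse matrix and calling it on a `Dual`-valued cotangent hits the same `convert(Float64, ::Dual)` failure seen in frame [3] of the trace above, since the sparse projector allocates a concrete `Vector{Float64}` for the nonzeros.

using ChainRulesCore, SparseArrays, ForwardDiff

A  = sprand(Float64, 5, 5, 0.5)
pA = ChainRulesCore.ProjectTo(A)                  # remembers A's sparsity structure and eltype
dA = ForwardDiff.Dual{Nothing}.(rand(5, 5), 1.0)  # a Dual-valued cotangent, like the one autoback_hesvec produces
pA(dA)                                            # errors: the sparse projector tries to store Duals in a Vector{Float64}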

tcovert commented 6 months ago

any update on this? I too encountered this problem today.

ChrisRackauckas commented 6 months ago

> I think this issue should be reported to SparseDiffTools.jl and probably fixed in ChainRules.jl

That was a bit of an odd response in https://github.com/CarloLucibello/GraphNeuralNetworks.jl/issues/125, and it led to this issue being opened here. The issue should be reported to ChainRules.jl, because it's ChainRules.jl that needs the fix. There's no code here that could or should be changed to handle this; any change would have to happen in ChainRules.jl. So waiting on a SparseDiffTools.jl change is not going to lead anywhere 😅.

Here's a version of the issue that uses no code from SparseDiffTools:

using Zygote, SparseArrays, ForwardDiff

x, v = rand(Float32, 5), rand(Float32, 5)
A = sprand(Float32, 5, 5, 0.5)
loss(_x) = sum(tanh.(A * _x))

T = typeof(ForwardDiff.Tag(nothing, eltype(x)))
# Seed the dual parts of x with v; this is what SparseDiffTools builds internally
# (its _default_autoback_hesvec_cache helper) before taking the Zygote gradient.
y = ForwardDiff.Dual{T, eltype(x), 1}.(x, ForwardDiff.Partials.(tuple.(reshape(v, size(x)))))
g = x -> first(Zygote.gradient(loss, x))
ForwardDiff.partials.(g(y), 1)  # extracts H(x) * v; fails with the same ProjectTo error as above

Let's take that upstream.
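
Until that is resolved upstream, here is a possible interim workaround sketch (not from the thread, so treat it as an assumption): use `numback_hesvec`, which the original report confirms works, or densify `A` so ChainRules never projects a cotangent onto a `SparseMatrixCSC`, at the cost of giving up the sparse multiply.

using Zygote, SparseArrays, SparseDiffTools

x, v = rand(Float32, 5), rand(Float32, 5)
A = sprand(Float32, 5, 5, 0.5)
loss(_x) = sum(tanh.(A * _x))

numback_hesvec(loss, x, v)           # works per the original report above

Ad = Matrix(A)                       # assumed fallback: densify to sidestep the sparse projection
loss_dense(_x) = sum(tanh.(Ad * _x))
autoback_hesvec(loss_dense, x, v)    # avoids the ProjectTo{SparseMatrixCSC} path (not verified in the thread)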

ChrisRackauckas commented 6 months ago

https://github.com/JuliaDiff/ChainRulesCore.jl/issues/648 is the upstream issue.