JuliaDiff / ChainRules.jl

Forward- and reverse-mode automatic differentiation primitives for Julia Base + StdLibs
Other
434 stars 89 forks source link

Missing rules for Tridiagonal #713

Open ChrisRackauckas opened 1 year ago

ChrisRackauckas commented 1 year ago

See https://discourse.julialang.org/t/optimization-jl-datainterpolations-jl-and-gradients/97676/2?u=chrisrackauckas

I just pulled out some old code:

# Adjoint for reading the sub-diagonal `A.dl` of a Tridiagonal matrix.
# The pullback embeds the incoming cotangent ȳ as the sub-diagonal of a
# Tridiagonal whose main and super-diagonals are zero (`zero` keeps A's eltype).
# `@adjoint` pullbacks must return one cotangent per argument, so the second
# entry is `nothing` for the `Val{:dl}` field selector.
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:dl})
    return A.dl, ȳ -> (Tridiagonal(ȳ, zero(A.d), zero(A.du)), nothing)
end
# Adjoint for reading the main diagonal `A.d` of a Tridiagonal matrix.
# The pullback embeds the incoming cotangent ȳ as the main diagonal of a
# Tridiagonal whose off-diagonals are zero (`zero` keeps A's eltype).
# `@adjoint` pullbacks must return one cotangent per argument, so the second
# entry is `nothing` for the `Val{:d}` field selector.
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:d})
    return A.d, ȳ -> (Tridiagonal(zero(A.dl), ȳ, zero(A.du)), nothing)
end
# Adjoint for reading the super-diagonal `A.du` of a Tridiagonal matrix.
# The forward pass must return `A.du` (not `A.dl`). The pullback embeds the
# incoming cotangent ȳ as the super-diagonal of a Tridiagonal whose sub- and
# main diagonals are zero (`zero` keeps A's eltype).
# `@adjoint` pullbacks must return one cotangent per argument, so the second
# entry is `nothing` for the `Val{:du}` field selector.
ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:du})
    return A.du, ȳ -> (Tridiagonal(zero(A.dl), zero(A.d), ȳ), nothing)
end
# Adjoint for the `Tridiagonal(dl, d, du)` constructor.
# Each input vector receives the matching band of the matrix-like cotangent p̄:
# offset -1 for the sub-diagonal `dl`, 0 for the main diagonal `d`, and +1 for
# the super-diagonal `du`. `diag(p̄, k)` reads the band directly instead of
# first copying the submatrices `p̄[2:end, 1:end-1]` / `p̄[1:end-1, 2:end]`.
ZygoteRules.@adjoint function Tridiagonal(dl, d, du)
    return Tridiagonal(dl, d, du), p̄ -> (diag(p̄, -1), diag(p̄), diag(p̄, 1))
end

But I presume it should be updated to use `ProjectTo` and then added here.

mcabbott commented 1 year ago

This is what https://github.com/JuliaDiff/ChainRulesCore.jl/pull/446 was meant to solve. Right now it gives the error below, but I haven't investigated further:

julia> # Solve
       result = solve(prob, ADAM(0.001, (0.9, 0.999)), maxiters=1000)
ERROR: ArgumentError: new: too few arguments (expected 4)
Stacktrace:
  [1] __new__
    @ ~/.julia/packages/Zygote/xGkZ5/src/tools/builtins.jl:9 [inlined]
  [2] adjoint
    @ ~/.julia/packages/Zygote/xGkZ5/src/lib/lib.jl:293 [inlined]
  [3] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
  [4] _pullback
    @ /Applications/Julia-1.9.app/Contents/Resources/julia/share/julia/stdlib/v1.9/LinearAlgebra/src/tridiag.jl:498 [inlined]
  [5] _pullback(::Zygote.Context{false}, ::Type{Tridiagonal{Float64, Vector{Float64}}}, ::Vector{Float64}, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/xGkZ5/src/compiler/interface2.jl:0
  [6] _pullback
    @ /Applications/Julia-1.9.app/Contents/Resources/julia/share/julia/stdlib/v1.9/LinearAlgebra/src/tridiag.jl:533 [inlined]
  [7] _pullback
    @ ~/.julia/packages/DataInterpolations/ivHqg/src/interpolation_caches.jl:160 [inlined]
  [8] _pullback(::Zygote.Context{false}, ::Type{CubicSpline}, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/xGkZ5/src/compiler/interface2.jl:0
  [9] _pullback
    @ ./REPL[11]:3 [inlined]
 [10] _pullback(::Zygote.Context{false}, ::typeof(obj), ::Vector{Float64}, ::Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/xGkZ5/src/compiler/interface2.jl:0
 [11] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [12] adjoint
    @ ~/.julia/packages/Zygote/xGkZ5/src/lib/lib.jl:203 [inlined]
 [13] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [14] _pullback
    @ ~/.julia/packages/SciMLBase/VdcHg/src/scimlfunctions.jl:3626 [inlined]
 [15] _pullback(::Zygote.Context{false}, ::OptimizationFunction{true, Optimization.AutoZygote, typeof(obj), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, ::Vector{Float64}, ::Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/xGkZ5/src/compiler/interface2.jl:0
 [16] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [17] adjoint
    @ ~/.julia/packages/Zygote/xGkZ5/src/lib/lib.jl:203 [inlined]
 [18] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [19] _pullback
    @ ~/.julia/packages/Optimization/RHDsr/src/function/zygote.jl:31 [inlined]
 [20] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#138#147"{OptimizationFunction{true, Optimization.AutoZygote, typeof(obj), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/xGkZ5/src/compiler/interface2.jl:0
 [21] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [22] adjoint
    @ ~/.julia/packages/Zygote/xGkZ5/src/lib/lib.jl:203 [inlined]
 [23] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [24] _pullback
    @ ~/.julia/packages/Optimization/RHDsr/src/function/zygote.jl:35 [inlined]
 [25] _pullback(ctx::Zygote.Context{false}, f::Optimization.var"#140#149"{Tuple{}, Optimization.var"#138#147"{OptimizationFunction{true, Optimization.AutoZygote, typeof(obj), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}}}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/xGkZ5/src/compiler/interface2.jl:0
 [26] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/xGkZ5/src/compiler/interface.jl:44
 [27] pullback
    @ ~/.julia/packages/Zygote/xGkZ5/src/compiler/interface.jl:42 [inlined]
 [28] gradient(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/xGkZ5/src/compiler/interface.jl:96
 [29] (::Optimization.var"#139#148"{Optimization.var"#138#147"{OptimizationFunction{true, Optimization.AutoZygote, typeof(obj), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}}})(::Vector{Float64}, ::Vector{Float64})
    @ Optimization ~/.julia/packages/Optimization/RHDsr/src/function/zygote.jl:33
 [30] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:31 [inlined]
 [31] macro expansion
    @ ~/.julia/packages/Optimization/RHDsr/src/utils.jl:37 [inlined]
 [32] __solve(prob::OptimizationProblem{true, OptimizationFunction{true, Optimization.AutoZygote, typeof(obj), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Vector{Float64}, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::Optimisers.Adam{Float64}, data::Base.Iterators.Cycle{Tuple{Optimization.NullData}}; maxiters::Int64, callback::Function, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:30
 [33] __solve (repeats 2 times)
    @ ~/.julia/packages/OptimizationOptimisers/FWIuf/src/OptimizationOptimisers.jl:7 [inlined]
 [34] #solve#553
    @ ~/.julia/packages/SciMLBase/VdcHg/src/solve.jl:86 [inlined]
 [35] top-level scope

julia> VERSION
v"1.9.0-rc1"

(jl_iOCibG) pkg> st
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_iOCibG/Project.toml`
  [d360d2e6] ChainRulesCore v1.15.0 `https://github.com/mcabbott/ChainRulesCore.jl#unstructural`
  [82cc6244] DataInterpolations v4.0.0
  [31c24e10] Distributions v0.25.87
  [7f7a1694] Optimization v3.13.1
  [253f991c] OptimizationFlux v0.1.4
  [42dfb2eb] OptimizationOptimisers v0.1.2