JuliaDiff / AbstractDifferentiation.jl

An abstract interface for automatic differentiation.
https://juliadiff.org/AbstractDifferentiation.jl/
MIT License

`pushforward_function` and `pullback_function` are confused by tuples vs single input #99

Open gdalle opened 1 year ago

gdalle commented 1 year ago

The setup:

```julia
julia> import AbstractDifferentiation as AD

julia> using ForwardDiff: ForwardDiff

julia> using Zygote: Zygote

julia> b1 = AD.ZygoteBackend();

julia> b2 = AD.ForwardDiffBackend();

julia> f(x) = x .^ 2;

julia> x = rand(3)
3-element Vector{Float64}:
 0.4953469957333393
 0.16373195021545772
 0.9601871509472656

julia> y = f(x)
3-element Vector{Float64}:
 0.24536864618204485
 0.026808151521357126
 0.921959364844227

julia> dx = rand(size(x)...)
3-element Vector{Float64}:
 0.5968881542176618
 0.05494767011762569
 0.18061398390944328

julia> dy = rand(size(y)...)
3-element Vector{Float64}:
 0.9491707280920829
 0.2878716471988746
 0.15674572721525504
```

A pushforward with the Zygote backend doesn't accept a single array as input:

```julia
julia> pf1 = AD.pushforward_function(b1, f, x);

julia> pf2 = AD.pushforward_function(b2, f, x);

julia> pf1((dx,))  # works
([0.5913335079610738, 0.017993378376308967, 0.3468464532624872],)

julia> pf1(dx)  # fails but shouldn't
ERROR: ArgumentError: Tuple contains 3 elements, must contain exactly 1 element
Stacktrace:
 [1] only(x::Tuple{Float64, Float64, Float64})
   @ Base.Iterators ./iterators.jl:1531
 [2] (::AbstractDifferentiation.var"#14#16"{Vector{Float64}, typeof(f), Tuple{Vector{Float64}}})(::Float64, ::Vararg{Float64})
   @ AbstractDifferentiation ~/Work/GitHub/Julia/AbstractDifferentiation.jl/src/AbstractDifferentiation.jl:172
 [3] (::AbstractDifferentiation.var"#25#27"{AbstractDifferentiation.ReverseRuleConfigBackend{Zygote.ZygoteRuleConfig{Zygote.Context{false}}}, AbstractDifferentiation.var"#14#16"{Vector{Float64}, typeof(f), Tuple{Vector{Float64}}}, Tuple{Float64, Float64, Float64}})(ws::Nothing)
   @ AbstractDifferentiation ~/Work/GitHub/Julia/AbstractDifferentiation.jl/src/AbstractDifferentiation.jl:249
 [4] jacobian(::AbstractDifferentiation.ReverseRuleConfigBackend{Zygote.ZygoteRuleConfig{Zygote.Context{false}}}, ::Function, ::Float64, ::Float64, ::Vararg{Float64})
   @ AbstractDifferentiationChainRulesCoreExt ~/Work/GitHub/Julia/AbstractDifferentiation.jl/src/AbstractDifferentiation.jl:551
 [5] (::AbstractDifferentiation.var"#13#15"{AbstractDifferentiation.ReverseRuleConfigBackend{Zygote.ZygoteRuleConfig{Zygote.Context{false}}}, typeof(f), Tuple{Vector{Float64}}})(ds::Vector{Float64})
   @ AbstractDifferentiation ~/Work/GitHub/Julia/AbstractDifferentiation.jl/src/AbstractDifferentiation.jl:166
 [6] top-level scope
   @ ~/Work/GitHub/Julia/ImplicitDifferentiation.jl/test/playground.jl:54

julia> pf2((dx,))  # works
([0.5913335079610738, 0.017993378376308967, 0.3468464532624872],)

julia> pf2(dx)  # works
([0.5913335079610738, 0.017993378376308967, 0.3468464532624872],)
```

A pullback with the ForwardDiff backend doesn't accept a tuple as input:

```julia
julia> pb1 = AD.pullback_function(b1, f, x);

julia> pb2 = AD.pullback_function(b2, f, x);

julia> pb1(dy)  # works
([0.9403377371968791, 0.09426757241521588, 0.301010466475946],)

julia> pb1((dy,))  # works
([0.9403377371968791, 0.09426757241521588, 0.301010466475946],)

julia> pb2(dy)  # works
([0.9403377371968791, 0.09426757241521588, 0.301010466475946],)

julia> pb2((dy,))  # fails but shouldn't
ERROR: AssertionError: length(vs) == length(ws)
Stacktrace:
 [1] (::AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)})(xs::Vector{ForwardDiff.Dual{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3}})
   @ AbstractDifferentiation ~/Work/GitHub/Julia/AbstractDifferentiation.jl/src/AbstractDifferentiation.jl:231
 [2] vector_mode_dual_eval!
   @ ~/.julia/packages/ForwardDiff/vXysl/src/apiutils.jl:24 [inlined]
 [3] vector_mode_gradient(f::AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3}}})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:89
 [4] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3}}}, ::Val{true})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:19
 [5] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3, Vector{ForwardDiff.Dual{ForwardDiff.Tag{AbstractDifferentiation.var"#88#90"{Tuple{Vector{Float64}}, AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f)}, Float64}, Float64, 3}}})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:17
 [6] gradient(ba::AbstractDifferentiation.ForwardDiffBackend{Nothing}, f::Function, x::Vector{Float64})
   @ AbstractDifferentiationForwardDiffExt ~/Work/GitHub/Julia/AbstractDifferentiation.jl/ext/AbstractDifferentiationForwardDiffExt.jl:46
 [7] (::AbstractDifferentiation.var"#87#89"{AbstractDifferentiation.ForwardDiffBackend{Nothing}, typeof(f), Tuple{Vector{Float64}}})(ws::Tuple{Vector{Float64}})
   @ AbstractDifferentiation ~/Work/GitHub/Julia/AbstractDifferentiation.jl/src/AbstractDifferentiation.jl:224
 [8] top-level scope
   @ ~/Work/GitHub/Julia/ImplicitDifferentiation.jl/test/playground.jl:63
```

Both of these uses are documented in the README.
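For illustration, here is a minimal sketch of the behaviour I would expect, written without any AD package. The helpers `as_tuple` and `normalize_input` are hypothetical names and not part of AbstractDifferentiation.jl, and `strict_pf` is a hand-written stand-in for a pushforward that only accepts tuples. The idea is to wrap a lone array in a 1-tuple before calling the underlying closure, so that both call forms give the same result.

```julia
# Hypothetical helper (not in AbstractDifferentiation.jl):
# leave tuples untouched, wrap anything else in a 1-tuple.
as_tuple(v::Tuple) = v
as_tuple(v) = (v,)

# Hypothetical wrapper: the wrapped function always receives a tuple.
normalize_input(g) = v -> g(as_tuple(v))

# Stand-in for a pushforward of f(x) = x .^ 2 at a fixed x that only
# accepts a tuple of tangents; the Jacobian-vector product is 2x .* dx.
x = [1.0, 2.0, 3.0]
strict_pf(ds::Tuple) = (2 .* x .* only(ds),)

pf = normalize_input(strict_pf)
dx = [0.1, 0.2, 0.3]

# Both call forms now go through the same code path.
@assert pf(dx) == pf((dx,))
```

This only covers the single-input case from this issue. With several inputs, a tuple of tangents is the natural form anyway, so the wrapper leaves it as-is.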