SciML / Integrals.jl

A common interface for quadrature and numerical integration for the SciML scientific machine learning organization
https://docs.sciml.ai/Integrals/stable/
MIT License
221 stars 30 forks source link

Zygote.gradient fails with CubaCuhre and batch != 0 #49

Open KirillZubov opened 3 years ago

KirillZubov commented 3 years ago
# Reproducer, part 1: unbatched integrand (batch=0), written against the older
# Quadrature.jl API (QuadratureProblem). A later comment ports it to Integrals.jl.
using Quadrature, ForwardDiff, FiniteDiff, Zygote, Cuba
# Scalar-valued integrand: sum over the components of sin(x_i * p_i).
f(x,p) = sum(sin.(x .* p))
# Integration domain is the cube [1, 3]^3.
lb = ones(3)
ub = 3ones(3)
p = [1.5,2.0,3.0]

prob = QuadratureProblem(f,lb,ub,p; batch=0)
sol = solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=100)[1]

# Wraps problem construction + solve so the integral can be differentiated
# with respect to the parameter vector p.
function testf(p)
    prob = QuadratureProblem(f,lb,ub,p, batch=0)
    solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=100)[1]
end
# Reverse-mode (Zygote), finite-difference, and forward-mode gradients, for comparison.
dp1 = Zygote.gradient(testf,p)
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# all three work fine

# Reproducer, part 2: the same problem with batch=10.
# NOTE(review): as pointed out later in the thread, f above is scalar-valued,
# which is incompatible with batching: a batched integrand should return one
# value per input point (a vector whose length matches the last axis of x).
prob = QuadratureProblem(f,lb,ub,p; batch=10)
sol = solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=100)[1]

# Redefines testf (overwriting the batch=0 version above) to use batch=10.
function testf(p)
    prob = QuadratureProblem(f,lb,ub,p, batch=10)
    solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=100)[1]
end
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# both work fine

# Fails with the MethodError shown below: a Vector{Float64} cannot be
# converted to Float64 inside Quadrature's adjoint for the batched integrand.
dp1 = Zygote.gradient(testf,p)

ERROR: MethodError: Cannot `convert` an object of type Array{Float64,1} to an object of type Float64
Closest candidates are:
  convert(::Type{T}, ::ArrayInterface.StaticInt{N}) where {T<:Number, N} at /Users/kirill/.julia/packages/ArrayInterface/rw2kK/src/static.jl:18
  convert(::Type{R}, ::T) where {R<:Real, T<:ReverseDiff.TrackedReal} at /Users/kirill/.julia/packages/ReverseDiff/jFRo1/src/tracked.jl:255
  convert(::Type{T}, ::Unitful.Quantity) where T<:Real at /Users/kirill/.julia/packages/Unitful/1t88N/src/conversion.jl:145
  ...
Stacktrace:
 [1] setindex! at ./array.jl:849 [inlined]
 [2] macro expansion at ./multidimensional.jl:802 [inlined]
 [3] macro expansion at ./cartesian.jl:64 [inlined]
 [4] macro expansion at ./multidimensional.jl:797 [inlined]
 [5] _unsafe_setindex!(::IndexLinear, ::Array{Float64,2}, ::Array{Array{Float64,1},1}, ::Base.Slice{Base.OneTo{Int64}}, ::Int64) at ./multidimensional.jl:789
 [6] _setindex! at ./multidimensional.jl:785 [inlined]
 [7] setindex!(::Array{Float64,2}, ::Array{Array{Float64,1},1}, ::Function, ::Int64) at ./abstractarray.jl:1153
 [8] (::Quadrature.var"#46#57"{QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}}})(::Array{Float64,2}, ::Array{Float64,1}) at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:524
 [9] __solvebp_call(::QuadratureProblem{false,Array{Float64,1},Quadrature.var"#46#57"{QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}}},Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}}}, ::CubaCuhre, ::Quadrature.ReCallVJP{Quadrature.ZygoteVJP}, ::Array{Float64,1}, ::Array{Float64,1}, ::Array{Float64,1}; reltol::Float64, abstol::Float64, maxiters::Int64, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:437
 [10] (::Quadrature.var"#quadrature_adjoint#52"{Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}},QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}},CubaCuhre,Quadrature.ReCallVJP{Quadrature.ZygoteVJP},Array{Float64,1},Array{Float64,1},Array{Float64,1},Tuple{}})(::Array{Float64,1}) at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:546
 [11] #65#back at /Users/kirill/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
 [12] #150 at /Users/kirill/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191 [inlined]
 [13] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{Quadrature.var"#65#back#64"{Quadrature.var"#quadrature_adjoint#52"{Base.Iterators.Pairs{Symbol,Real,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:reltol, :abstol, :maxiters),Tuple{Float64,Float64,Int64}}},QuadratureProblem{false,Array{Float64,1},typeof(f),Array{Float64,1},Array{Float64,1},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}},CubaCuhre,Quadrature.ReCallVJP{Quadrature.ZygoteVJP},Array{Float64,1},Array{Float64,1},Array{Float64,1},Tuple{}}},Tuple{NTuple{8,Nothing},Tuple{}}}})(::Array{Float64,1}) at /Users/kirill/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [14] #solve#10 at /Users/kirill/.julia/packages/Quadrature/ZmWGy/src/Quadrature.jl:149 [inlined]
 [15] (::typeof(∂(#solve#10)))(::Array{Float64,1}) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [16] #150 at /Users/kirill/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191 [inlined]
 [17] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{typeof(∂(#solve#10)),Tuple{NTuple{5,Nothing},Tuple{}}}})(::Array{Float64,1}) at /Users/kirill/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [18] (::typeof(∂(solve##kw)))(::Array{Float64,1}) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [19] testf at ./none:3 [inlined]
 [20] (::typeof(∂(testf)))(::Float64) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#41#42"{typeof(∂(testf))})(::Float64) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:45
 [22] gradient(::Function, ::Array{Float64,1}) at /Users/kirill/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
 [23] top-level scope at none:1
lxvm commented 4 months ago

I think there is a bug in the MWE: a scalar-valued `f` is incompatible with batching. A batched integrand must return a vector whose length matches the last axis of the input points, i.e. one value per point (see the FAQ for more details).

I've adapted the MWE to the current version of Integrals.jl and modified the integrand to do what I think was intended. I can confirm that it works on the master branch.

# Reproducer ported to the Integrals.jl API (IntegralProblem), with the
# integrand changed so it is compatible with batching.
using Integrals, ForwardDiff, FiniteDiff, Zygote, Cuba
# sum(...; dims=1) reduces over the first axis only. For a matrix x whose
# columns are points, x .* p scales row i by p[i], and the result is a 1×N row
# with one value per point (last axis). For a vector x it returns a 1-element
# array rather than a scalar.
f(x,p) = sum(sin.(x .* p); dims=1)
# Integration domain is the cube [1, 3]^3.
lb = ones(3)
ub = 3ones(3)
p = [1.5,2.0,3.0]

# Unbatched problem; [1] takes the first entry of the solution.
prob = IntegralProblem(f,lb,ub,p)
sol = solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=1000)[1]

# Integral as a function of p, for differentiation.
function testf(p)
    prob = IntegralProblem(f,lb,ub,p)
    solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=1000)[1]
end
dp1 = Zygote.gradient(testf,p)
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# all three work fine

# Batched problem. NOTE(review): the `batch` keyword matches the Integrals.jl
# version used in this comment; confirm against the installed release.
prob = IntegralProblem(f,lb,ub,p; batch=10)
sol = solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=1000)[1]

# Redefines testf (overwriting the unbatched version above) to use batch=10.
function testf(p)
    prob = IntegralProblem(f,lb,ub,p, batch=10)
    solve(prob,CubaCuhre(),reltol= 1e-4,abstol= 1e-4,maxiters=1000)[1]
end
dp2 = FiniteDiff.finite_difference_gradient(testf,p)
dp3 = ForwardDiff.gradient(testf,p)
# both work fine

# Reported to work on the master branch with the batch-compatible integrand.
dp1 = Zygote.gradient(testf,p)

Some bugs that affect CubaCuhre in the current release are fixed only on the master branch, so I'll keep this issue open until the next release.