SciML / DiffEqFlux.jl

Pre-built implicit layer architectures with O(1) backprop, GPUs, and stiff+non-stiff DE solvers, demonstrating scientific machine learning (SciML) and physics-informed machine learning methods
https://docs.sciml.ai/DiffEqFlux/stable
MIT License
870 stars 157 forks source link

Higher-order derivatives of ffjord on gpu fail #652

Open ccasert opened 2 years ago

ccasert commented 2 years ago

I need to calculate the Laplacian of the densities modelled by a normalizing flow with respect to the inputs. On CPU, I can e.g. use the following code (which works, but seems to scale poorly with the number of samples)

nn = Chain(
    Dense(2, 32, tanh),
    Dense(32, 2),
) |> f32

tspan = (0.0f0, 1.0f0)
ffjord_mdl = FFJORD(nn, tspan, Tsit5())

function loss(x)
    logpx, λ₁, λ₂ = ffjord_mdl(x)
    return logpx
end

function lapl(x)
    return Zygote.diaghessian(x->sum(loss(x)), x)
end

data_dist = Normal(0.0f0, 1.0f0)
train_data = rand(data_dist, 2, 10)
lapl(train_data)

However, when I attempt to run this code on GPU

nn = Chain(
    Dense(2, 32, tanh),
    Dense(32, 2),
) |> gpu

tspan = (0.0f0, 1.0f0)
ffjord_mdl = FFJORD(nn, tspan, Tsit5())

function loss(x)
    e = randn(Float32, size(x)) |> gpu
    logpx, λ₁, λ₂ = ffjord_mdl(x, ffjord_mdl.p, e)
    return logpx
end

function lapl(x)
    return Zygote.diaghessian(x->sum(loss(x)), x)
end

data_dist = Normal(0.0f0, 1.0f0)
train_data = gpu(rand(data_dist, 2, 10))
lapl(train_data)

I get the following error:

ERROR: MethodError: no method matching cudnnDataType(::Type{ForwardDiff.Dual{Nothing, Float32, 12}})
Closest candidates are:
  cudnnDataType(::Type{Float16}) at /data/packages/CUDA/YpW0k/lib/cudnn/util.jl:7
  cudnnDataType(::Type{Float32}) at /data/packages/CUDA/YpW0k/lib/cudnn/util.jl:8
  cudnnDataType(::Type{Float64}) at /data/packages/CUDA/YpW0k/lib/cudnn/util.jl:9
  ...
Stacktrace:
  [1] CUDA.CUDNN.cudnnTensorDescriptor(array::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}; format::CUDA.CUDNN.cudnnTensorFormat_t, dims::Vector{Int32})
    @ CUDA.CUDNN /data/packages/CUDA/YpW0k/lib/cudnn/tensor.jl:9
  [2] CUDA.CUDNN.cudnnTensorDescriptor(array::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ CUDA.CUDNN /data/packages/CUDA/YpW0k/lib/cudnn/tensor.jl:8
  [3] cudnnActivationForward!(y::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, x::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}; o::Base.Iterators.Pairs{Symbol, CUDA.CUDNN.cudnnActivationMode_t, Tuple{Symbol}, NamedTuple{(:mode,), Tuple{CUDA.CUDNN.cudnnActivationMode_t}}})
    @ CUDA.CUDNN /data/packages/CUDA/YpW0k/lib/cudnn/activation.jl:22
  [4] (::NNlibCUDA.var"#64#68")(src::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, dst::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ NNlibCUDA /data/packages/NNlibCUDA/gWBCU/src/cudnn/activations.jl:10
  [5] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(tanh), Tuple{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}}})
    @ NNlibCUDA /data/packages/NNlibCUDA/gWBCU/src/cudnn/activations.jl:30
  [6] adjoint
    @ /data/packages/Zygote/AlLTp/src/lib/broadcast.jl:102 [inlined]
  [7] _pullback(__context__::Zygote.Context, 641::typeof(Base.Broadcast.broadcasted), 642::typeof(tanh), x::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/ZygoteRules/AIbCs/src/adjoint.jl:65
  [8] _pullback
    @ /data/packages/Flux/BPPNj/src/layers/basic.jl:158 [inlined]
  [9] _pullback(ctx::Zygote.Context, f::Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [10] _pullback
    @ /data/packages/Flux/BPPNj/src/layers/basic.jl:47 [inlined]
 [11] _pullback(::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [12] _pullback
    @ /data/packages/Flux/BPPNj/src/layers/basic.jl:49 [inlined]
 [13] _pullback(ctx::Zygote.Context, f::Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [14] _pullback
    @ /data/packages/Zygote/AlLTp/src/compiler/interface.jl:34 [inlined]
 [15] pullback(f::Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface.jl:40
 [16] ffjord(u::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, p::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, t::Float32, re::Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, e::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}; regularize::Bool, monte_carlo::Bool)
    @ DiffEqFlux /data/packages/DiffEqFlux/jpIWG/src/ffjord.jl:186
 [17] ffjord_
    @ /data/packages/DiffEqFlux/jpIWG/src/ffjord.jl:204 [inlined]
 [18] ODEFunction
    @ /data/packages/SciMLBase/x3z0g/src/scimlfunctions.jl:334 [inlined]
 [19] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, false, CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Nothing, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float32, Float32, Float32, Float32, Vector{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}}, ODESolution{ForwardDiff.Dual{Nothing, Float32, 12}, 3, Vector{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}}}, ODEProblem{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, 
Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}}, Vector{Float32}, Vector{Vector{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, 
Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, OrdinaryDiffEq.DEOptions{ForwardDiff.Dual{Nothing, Float32, 12}, ForwardDiff.Dual{Nothing, Float32, 12}, Float32, Float32, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Bool, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, ForwardDiff.Dual{Nothing, Float32, 12}, Nothing, OrdinaryDiffEq.DefaultInit}, cache::OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32})
    @ OrdinaryDiffEq /data/packages/OrdinaryDiffEq/JsAS0/src/perform_step/low_order_rk_perform_step.jl:569
 [20] __init(prob::ODEProblem{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, 
qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{Int64}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Iterators.Pairs{Symbol, Bool, Tuple{Symbol}, NamedTuple{(:save_noise,), Tuple{Bool}}})
    @ OrdinaryDiffEq /data/packages/OrdinaryDiffEq/JsAS0/src/solve.jl:456
 [21] #__solve#493
    @ /data/packages/OrdinaryDiffEq/JsAS0/src/solve.jl:4 [inlined]
 [22] #solve_call#42
    @ /data/packages/DiffEqBase/b1nST/src/solve.jl:61 [inlined]
 [23] solve_up(prob::ODEProblem{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, sensealg::Nothing, u0::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, p::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, args::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}; kwargs::Base.Iterators.Pairs{Symbol, Bool, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:save_noise, :save_start, :save_end), Tuple{Bool, Bool, Bool}}})
    @ DiffEqBase /data/packages/DiffEqBase/b1nST/src/solve.jl:87
 [24] #solve#43
    @ /data/packages/DiffEqBase/b1nST/src/solve.jl:73 [inlined]
 [25] _concrete_solve_adjoint(::ODEProblem{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::InterpolatingAdjoint{0, true, Val{:central}, Bool, Bool}, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}; save_start::Bool, save_end::Bool, saveat::Vector{Float32}, save_idxs::Nothing, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqSensitivity /data/packages/DiffEqSensitivity/uakCr/src/concrete_solve.jl:151
 [26] _concrete_solve_adjoint
    @ /data/packages/DiffEqSensitivity/uakCr/src/concrete_solve.jl:131 [inlined]
 [27] #_solve_adjoint#61
    @ /data/packages/DiffEqBase/b1nST/src/solve.jl:347 [inlined]
 [28] _solve_adjoint
    @ /data/packages/DiffEqBase/b1nST/src/solve.jl:322 [inlined]
 [29] #rrule#59
    @ /data/packages/DiffEqBase/b1nST/src/solve.jl:310 [inlined]
 [30] rrule
    @ /data/packages/DiffEqBase/b1nST/src/solve.jl:310 [inlined]
 [31] rrule
    @ /data/packages/ChainRulesCore/7ZiwT/src/rules.jl:134 [inlined]
 [32] chain_rrule
    @ /data/packages/Zygote/AlLTp/src/compiler/chainrules.jl:216 [inlined]
 [33] macro expansion
    @ /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0 [inlined]
 [34] _pullback
    @ /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:9 [inlined]
 [35] _apply
    @ ./boot.jl:804 [inlined]
 [36] adjoint
    @ /data/packages/Zygote/AlLTp/src/lib/lib.jl:200 [inlined]
 [37] _pullback
    @ /data/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [38] _pullback
    @ /data/packages/DiffEqBase/b1nST/src/solve.jl:73 [inlined]
 [39] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve#43", ::InterpolatingAdjoint{0, true, Val{:central}, Bool, Bool}, ::Nothing, ::Nothing, ::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(solve), ::ODEProblem{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [40] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:804
 [41] adjoint
    @ /data/packages/Zygote/AlLTp/src/lib/lib.jl:200 [inlined]
 [42] _pullback
    @ /data/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [43] _pullback
    @ /data/packages/DiffEqBase/b1nST/src/solve.jl:68 [inlined]
 [44] _pullback(::Zygote.Context, ::CommonSolve.var"#solve##kw", ::NamedTuple{(:sensealg,), Tuple{InterpolatingAdjoint{0, true, Val{:central}, Bool, Bool}}}, ::typeof(solve), ::ODEProblem{CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#61"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [45] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:804
 [46] adjoint
    @ /data/packages/Zygote/AlLTp/src/lib/lib.jl:200 [inlined]
 [47] _pullback
    @ /data/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [48] _pullback
    @ /data/packages/DiffEqFlux/jpIWG/src/ffjord.jl:218 [inlined]
 [49] _pullback(::Zygote.Context, ::DiffEqFlux.var"##forward_ffjord#56", ::Bool, ::Bool, ::typeof(DiffEqFlux.forward_ffjord), ::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [50] _pullback
    @ /data/packages/DiffEqFlux/jpIWG/src/ffjord.jl:202 [inlined]
 [51] _pullback(::Zygote.Context, ::typeof(DiffEqFlux.forward_ffjord), ::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [52] _apply(::Function, ::Vararg{Any, N} where N)
    @ Core ./boot.jl:804
 [53] adjoint
    @ /data/packages/Zygote/AlLTp/src/lib/lib.jl:200 [inlined]
 [54] _pullback
    @ /data/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [55] _pullback
    @ /data/packages/DiffEqFlux/jpIWG/src/ffjord.jl:198 [inlined]
--- the last 5 lines are repeated 1 more time ---
 [61] _pullback(::Zygote.Context, ::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [62] _pullback
    @ ./REPL[6]:3 [inlined]
 [63] _pullback(ctx::Zygote.Context, f::typeof(loss), args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [64] _pullback
    @ ./REPL[7]:2 [inlined]
 [65] _pullback(ctx::Zygote.Context, f::var"#1#2", args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface2.jl:0
 [66] _pullback(f::Function, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface.jl:34
 [67] pullback(f::Function, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface.jl:40
 [68] gradient(f::Function, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/compiler/interface.jl:75
 [69] (::Zygote.var"#105#108"{Int64, Val{1}, var"#1#2", Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})(x::CuArray{ForwardDiff.Dual{Nothing, Float32, 12}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/lib/grad.jl:272
 [70] forward_diag(f::Zygote.var"#105#108"{Int64, Val{1}, var"#1#2", Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, #unused#::Val{12})
    @ Zygote /data/packages/Zygote/AlLTp/src/lib/forward.jl:65
 [71] forward_diag(f::Function, x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote /data/packages/Zygote/AlLTp/src/lib/forward.jl:80
 [72] #104
    @ /data/packages/Zygote/AlLTp/src/lib/grad.jl:272 [inlined]
 [73] ntuple
    @ ./ntuple.jl:19 [inlined]
 [74] diaghessian
    @ /data/packages/Zygote/AlLTp/src/lib/grad.jl:269 [inlined]
 [75] lapl(x::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Main ./REPL[7]:2
 [76] top-level scope
    @ ./timing.jl:210 [inlined]
 [77] top-level scope
    @ ./REPL[11]:0

These are my package versions:

(@v1.6) pkg> status
      Status `/data/environments/v1.6/Project.toml`
  [6e4b80f9] BenchmarkTools v1.2.0
  [052768ef] CUDA v3.5.0
  [aae7a2af] DiffEqFlux v1.44.0
  [41bf760c] DiffEqSensitivity v6.60.3
  [0c46a032] DifferentialEquations v6.20.0
  [31c24e10] Distributions v0.25.29
  [587475ba] Flux v0.12.8
  [a75be94c] GalacticOptim v2.2.0
  [429524aa] Optim v1.5.0
  [1dea7af3] OrdinaryDiffEq v5.67.0
  [91a5bcdd] Plots v1.23.6
  [e88e6eb3] Zygote v0.6.30
  [de0858da] Printf
  [10745b16] Statistics

Any help would be appreciated!

ChrisRackauckas commented 2 years ago

@DhairyaLGandhi any good way around this?

DomCRose commented 2 years ago

Just curious whether there has been any movement on this? Or whether there is an alternative way of computing the Laplacian that works, using other functions / AD packages.

DomCRose commented 2 years ago

Not sure if it's related, but if I simply try to call the loss of this FFJORD code I get a scalar-indexing-on-a-GPU-array error. It seems to point to the [:, :, end] slice on the solve in forward_ffjord.

Code:

nn = Chain(
    Dense(2, 32, tanh),
    Dense(32, 2),
) |> gpu

tspan = (0.0f0, 1.0f0)
ffjord_mdl = FFJORD(nn, tspan, Tsit5())

function loss(x)
    e = randn(Float32, size(x)) |> gpu
    logpx, λ₁, λ₂ = ffjord_mdl(x, ffjord_mdl.p, e)
    return logpx
end

function lapl(x)
    return Zygote.diaghessian(x -> sum(loss(x)), x)
end

data_dist = Normal(0.0f0, 1.0f0)
train_data = gpu(rand(data_dist, 2))
loss(train_data)

Error:

ERROR: LoadError: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base .\error.jl:33
  [2] assertscalar(op::String)
    @ GPUArrays C:\Users\domin\.julia\packages\GPUArrays\gkF6S\src\host\indexing.jl:53
  [3] getindex(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Int64, ::Int64)
    @ GPUArrays C:\Users\domin\.julia\packages\GPUArrays\gkF6S\src\host\indexing.jl:86
  [4] getindex
    @ C:\Users\domin\.julia\packages\RecursiveArrayTools\gr5FR\src\vector_of_array.jl:164 [inlined]
  [5] macro expansion
    @ .\multidimensional.jl:860 [inlined]
  [6] macro expansion
    @ .\cartesian.jl:64 [inlined]
  [7] macro expansion
    @ .\multidimensional.jl:855 [inlined]
  [8] _unsafe_getindex!
    @ .\multidimensional.jl:868 [inlined]
  [9] _unsafe_getindex(::IndexCartesian, ::VectorOfArray{Float32, 3, Vector{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, ::Base.Slice{Base.OneTo{Int64}}, ::Base.Slice{Base.OneTo{Int64}}, ::Int64)
    @ Base .\multidimensional.jl:846
 [10] _getindex
    @ .\multidimensional.jl:832 [inlined]
 [11] getindex
    @ .\abstractarray.jl:1170 [inlined]
 [12] getindex(::ODESolution{Float32, 3, Vector{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, ODEProblem{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ODEFunction{false, DiffEqFlux.var"#ffjord_#63"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{false, DiffEqFlux.var"#ffjord_#63"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, 
CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Vector{Float32}, Vector{Vector{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats}, ::Colon, ::Colon, ::Int64)
    @ SciMLBase C:\Users\domin\.julia\packages\SciMLBase\jj8Ix\src\solutions\solution_interface.jl:33
 [13] forward_ffjord(n::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, p::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, e::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}; regularize::Bool, monte_carlo::Bool)
    @ DiffEqFlux C:\Users\domin\.julia\packages\DiffEqFlux\w4Zm0\src\ffjord.jl:219
 [14] forward_ffjord(n::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, p::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, e::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ DiffEqFlux C:\Users\domin\.julia\packages\DiffEqFlux\w4Zm0\src\ffjord.jl:203
 [15] (::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}})(::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, N} where N; kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqFlux C:\Users\domin\.julia\packages\DiffEqFlux\w4Zm0\src\ffjord.jl:199
 [16] (::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}, MvNormal{Float32, PDMats.PDiagMat{Float32, Vector{Float32}}, Vector{Float32}}, Tuple{Float32, Float32}, Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}})(::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, N} where N)
    @ DiffEqFlux C:\Users\domin\.julia\packages\DiffEqFlux\w4Zm0\src\ffjord.jl:199
 [17] loss(x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Main c:\Users\domin\Dropbox\code_practice\julia\diffeqflux\ffjord_test.jl:38
 [18] top-level scope
    @ c:\Users\domin\Dropbox\code_practice\julia\diffeqflux\ffjord_test.jl:48
in expression starting at c:\Users\domin\Dropbox\code_practice\julia\diffeqflux\ffjord_test.jl:48

Environment status:

[052768ef] CUDA v3.6.4
[aae7a2af] DiffEqFlux v1.44.1
[0c46a032] DifferentialEquations v7.1.0
[31c24e10] Distributions v0.25.38
ChrisRackauckas commented 2 years ago

https://github.com/SciML/DiffEqFlux.jl/pull/614 is probably the solution when it's finished.

DomCRose commented 2 years ago

A small update: https://github.com/FluxML/NNlibCUDA.jl/pull/48 fixes the original bug in this issue. However, there remains another bug (which now looks Zygote-related) in the diaghessian call. The scalar indexing in the forward call of the loss also remains.