SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Does ZygoteVJP() support training Neural ODE with Discretecallback on GPU? #1093

Open · yunan-l opened 3 months ago

yunan-l commented 3 months ago

Hi, I tried to train a Neural ODE with a `DiscreteCallback` using `sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true)` on GPU, but it failed with:

Only `ReverseDiffVJP` and `EnzymeVJP` are currently compatible with continuous adjoint sensitivity methods for hybrid DEs. Please select `ReverseDiffVJP` or `EnzymeVJP` as `autojacvec`.

Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] _setup_reverse_callbacks(cb::DiscreteCallback{DiffEqCallbacks.var"#109#113"{Vector{Float32}}, SciMLSensitivity.TrackedAffect{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, DiffEqCallbacks.var"#111#115"{typeof(affect!)}, Nothing, Int64}, DiffEqCallbacks.var"#112#116"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, typeof(affect!)}, typeof(SciMLBase.FINALIZE_DEFAULT)}, affect::SciMLSensitivity.TrackedAffect{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, DiffEqCallbacks.var"#111#115"{typeof(affect!)}, Nothing, Int64}, sensealg::InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}, dgdu::Function, dgdp::Nothing, loss_ref::Base.RefValue{Int64}, terminated::Bool)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/PstNN/src/callback_tracking.jl:244
  [3] _setup_reverse_callbacks(cb::DiscreteCallback{DiffEqCallbacks.var"#109#113"{Vector{Float32}}, SciMLSensitivity.TrackedAffect{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, ComponentVector{Float32, CuArray{Float32, 1, CUDA.DeviceMemory}, Tuple{Axis{(layer_1 = ViewAxis(1:80, Axis(weight = ViewAxis(1:64, ShapedAxis((16, 4))), bias = ViewAxis(65:80, ShapedAxis((16, 1))))), layer_2 = ViewAxis(81:352, Axis(weight = ViewAxis(1:256, ShapedAxis((16, 16))), bias = ViewAxis(257:272, ShapedAxis((16, 1))))), layer_3 = ViewAxis(353:420, Axis(weight = ViewAxis(1:64, ShapedAxis((4, 16))), bias = ViewAxis(65:68, ShapedAxis((4, 1))))))}}}, DiffEqCallbacks.var"#111#115"{typeof(affect!)}, Nothing, Int64}, DiffEqCallbacks.var"#112#116"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, typeof(affect!)}, typeof(SciMLBase.FINALIZE_DEFAULT)}, sensealg::InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}, dgdu::Function, dgdp::Nothing, loss_ref::Base.RefValue{Int64}, terminated::Bool)
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/PstNN/src/callback_tracking.jl:219

So `ZygoteVJP()` doesn't support a `DiscreteCallback` on GPU, right?
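
For concreteness, here is a minimal side-by-side of the rejected and requested configurations (a sketch; only the `autojacvec` choice differs from the setup described above):

```julia
using SciMLSensitivity

# What was attempted: ZygoteVJP, which the continuous adjoint setup
# rejects for hybrid DEs (DEs with callbacks).
rejected = InterpolatingAdjoint(autojacvec = ZygoteVJP(), checkpointing = true)

# What the error message asks for instead: ReverseDiffVJP or EnzymeVJP.
accepted = InterpolatingAdjoint(autojacvec = EnzymeVJP(), checkpointing = true)
```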

ChrisRackauckas commented 2 months ago

Yes, it needs to be `ReverseDiffVJP` or `EnzymeVJP`. For GPU, then, it would need to be `EnzymeVJP`. I don't quite know how much coverage Enzyme has on GPU now, but it should be getting close @wsmoses @avik-pal? It would be good to have an example to work through with this.
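
Picking up on the request for an example, below is a minimal, untested sketch of such a setup. The network sizes, callback time, and toy loss are invented for illustration, and `gdev`, `dudt!`, and `affect!` are assumed names; the key points are the in-place form of the ODE function (required by `EnzymeVJP`) and `autojacvec = EnzymeVJP()` in place of `ZygoteVJP()`:

```julia
using OrdinaryDiffEq, SciMLSensitivity, DiffEqCallbacks
using Lux, LuxCUDA, ComponentArrays, CUDA, Zygote, Random

rng = Random.default_rng()
gdev = gpu_device()

# A small network for the ODE right-hand side (sizes invented for illustration).
nn = Chain(Dense(2 => 16, tanh), Dense(16 => 2))
ps, st = Lux.setup(rng, nn)
ps = ComponentArray(ps) |> gdev

# EnzymeVJP requires the in-place form of the ODE function.
function dudt!(du, u, p, t)
    y, _ = nn(u, p, st)
    du .= y
    return nothing
end

# A discrete event at t = 0.5 that rescales the state in place, analogous
# to the PresetTimeCallback-style affect! in the stacktrace above.
affect!(integrator) = (integrator.u .*= 0.9f0)
cb = PresetTimeCallback([0.5f0], affect!)

u0 = gdev(Float32[1.0, 0.0])
prob = ODEProblem(dudt!, u0, (0.0f0, 1.0f0), ps)

function loss(p)
    sol = solve(prob, Tsit5(); p = p, callback = cb,
        sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP(),
            checkpointing = true))
    sum(abs2, last(sol.u))  # toy loss on the final state
end

grads = Zygote.gradient(loss, ps)
```

On the CPU path, swapping `EnzymeVJP()` for `ReverseDiffVJP()` in the same `solve` call is the other option the error message allows. Whether the Enzyme path currently works end-to-end on CUDA is exactly what the issues linked below track.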

wsmoses commented 2 months ago

There are some minor things like https://github.com/LuxDL/LuxLib.jl/issues/148 and https://github.com/JuliaGPU/CUDA.jl/pull/2471, which need to be implemented and merged, respectively, but otherwise things should generally work for CUDA.

avik-pal commented 2 months ago

Just as a side note: even once https://github.com/LuxDL/LuxLib.jl/issues/148 is resolved, I need https://github.com/LuxDL/Lux.jl/pull/744 (at least 2-3 weeks) to be merged before the LuxLib fixes are available to end users.