avik-pal / RegNeuralDE.jl

Official Implementation of "Opening the Blackbox: Accelerating Neural Differential Equations by Regularizing Internal Solver Heuristics" (ICML 2021)
http://proceedings.mlr.press/v139/pal21a.html
MIT License

GPU Compilation for Tracker Backward Pass through savevals #42

Open avik-pal opened 3 years ago

avik-pal commented 3 years ago

The backward pass on GPUs currently fails with:

ERROR: InvalidIRError: compiling kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceArray{Float32,2,CUDA.AS.Global}, Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}},typeof(Tracker.partial),Tuple{CUDA.CuRefValue{typeof(DiffEqBase.calculate_residuals)},Base.Broadcast.Extruded{CuDeviceArray{Float32,2,CUDA.AS.Global},Tuple{Bool,Bool},Tuple{Int64,Int64}},Int64,Base.Broadcast.Extruded{CuDeviceArray{Float32,2,CUDA.AS.Global},Tuple{Bool,Bool},Tuple{Int64,Int64}},Base.Broadcast.Extruded{CuDeviceArray{Float32,2,CUDA.AS.Global},Tuple{Bool,Bool},Tuple{Int64,Int64}},Base.Broadcast.Extruded{CuDeviceArray{Float32,2,CUDA.AS.Global},Tuple{Bool,Bool},Tuple{Int64,Int64}},Float32,Float32,CUDA.CuRefValue{typeof(DiffEqBase.ODE_DEFAULT_NORM)},Float32}}, Int64) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to partial(f::F, Δ, i, args::Vararg{Any,N}) where {F, N} in Tracker at /mnt/research/Tracker/src/lib/array.jl:546)
Stacktrace:
 [1] _broadcast_getindex_evalf at broadcast.jl:648
 [2] _broadcast_getindex at broadcast.jl:621
 [3] getindex at broadcast.jl:575
 [4] broadcast_kernel at /home/avikpal/.julia/packages/GPUArrays/uaFZh/src/host/broadcast.jl:62

I was able to trace the failure to the point where Tracker's partial is called with f = DiffEqBase.calculate_residuals.

avik-pal commented 3 years ago

With the patch in https://github.com/avik-pal/DiffEqBase.jl/tree/ap/fix_gpu_regnode, GPU compilation works. The fix comes down to inlining the functions.

ChrisRackauckas commented 3 years ago

Why would inlining matter?

avik-pal commented 3 years ago

I looked into it a bit more but couldn't figure out the exact reason, and I couldn't find any similar issue either. https://github.com/avik-pal/DiffEqBase.jl/commit/f1bf9927c3026bf782d8ff67297b8428efbdf503 is the patch needed to make it work.