Open yunan-l opened 3 months ago
Yes, it needs to be ReverseDiffVJP or EnzymeVJP. On GPU, it would need to be EnzymeVJP. I don't know exactly how much GPU coverage Enzyme has right now, but it should be getting close, @wsmoses @avik-pal? It would be good to have an example to work through with this.
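As a starting point for such an example, here is a minimal CPU-only sketch of the suggested change: the same `InterpolatingAdjoint(...; checkpointing=true)` setup, with `ZygoteVJP()` swapped for `ReverseDiffVJP()`. Everything except the `sensealg` line is an assumption made for illustration: the toy dynamics `f!` stand in for the neural network, and the callback time, state jump, and loss are hypothetical. It has not been run on GPU; per the comment above, the GPU case would use `EnzymeVJP()` instead, subject to Enzyme's current coverage.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Toy in-place dynamics standing in for the neural network (assumption).
function f!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end

u0 = [1.0, 1.0]
tspan = (0.0, 1.0)
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(f!, u0, tspan, p)

# Hypothetical discrete event: bump the first state at t = 0.5.
# The time is also passed as a tstop so the integrator steps exactly onto it.
condition(u, t, integrator) = t == 0.5
affect!(integrator) = (integrator.u[1] += 0.1)
cb = DiscreteCallback(condition, affect!)

# ReverseDiffVJP() replaces ZygoteVJP(); on GPU this would be EnzymeVJP().
sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(), checkpointing = true)

function loss(p)
    sol = solve(prob, Tsit5(); p = p, callback = cb, tstops = [0.5],
                saveat = 0.1, sensealg = sensealg)
    return sum(abs2, Array(sol))
end

# Gradient of the loss with respect to the parameters, through the callback.
grad = Zygote.gradient(loss, p)[1]
```

If this runs on CPU but the GPU version with `EnzymeVJP()` does not, the difference should help narrow down which part of the GPU path is still missing.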
There are some minor things: https://github.com/LuxDL/LuxLib.jl/issues/148 needs to be implemented and https://github.com/JuliaGPU/CUDA.jl/pull/2471 needs to be merged. Otherwise, things should generally work for CUDA.
Just as a side note: even once https://github.com/LuxDL/LuxLib.jl/issues/148 is resolved, I need https://github.com/LuxDL/Lux.jl/pull/744 to be merged (at least 2-3 weeks away) before the LuxLib fixes are available to end users.
Hi, I tried to train a Neural ODE with a DiscreteCallback using
sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP(), checkpointing=true)
on GPU, but it raised an error. So ZygoteVJP() doesn't support DiscreteCallback on GPU, right?