ChrisRackauckas opened 2 years ago
Seems like `f` outputs Arrays. That would be the function Zygote sees, as I understand it. This would need to be on the GPU. I will try to get an MWE rolling on my end, but it seems like the output is on the CPU; stack frame [9] suggests that. I also think `p` should be on the GPU for consistency.
That accumulation happens on the CPU: https://github.com/SciML/NeuralPDE.jl/blob/master/src/ode_solve.jl#L197-L201
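For context, the NNODE trial solution has the form φ(t) = u0 + (t - t0) · NN(t; θ) (the accumulation linked above), so `u0`, the time points, and the parameters all need to live on the same device. A minimal sketch of that rule, with `NN`, `θ`, and `ts` as illustrative stand-ins rather than NeuralPDE's actual internals:

```julia
using CUDA

NN(t, θ) = θ * t                     # stand-in for the network's GPU forward pass
θ = cu(rand(5, 1))                   # parameters on the GPU
u0, t0 = 0.0f0, 0.0f0                # initial condition and time (CPU scalars)
ts = reshape(collect(0.0f0:0.1f0:1.0f0), 1, :)   # 1×11 CPU row of time points

# φ = u0 .+ (t .- t0) .* NN(t, θ): scalars broadcast fine against GPU arrays,
# but the time points themselves must be moved to the device first.
tsg = cu(ts)
φ = u0 .+ (tsg .- t0) .* NN(tsg, θ)  # 5×11 CuArray, everything on the GPU
```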
It looks like this is another instantiation of https://github.com/SciML/NeuralPDE.jl/issues/533, because I did `chain = FastChain(FastDense(1, N, relu), FastDense(N, N, relu), FastDense(N, N, relu), FastDense(N, length(u0)))` instead of `chain = FastChain(FastDense(1, N, relu), FastDense(N, N, relu), FastDense(N, N, relu), FastDense(N, length(u0))) |> gpu`, and the types in the NN matter even though they aren't used.
Actually nope, this case was using `FastChain`, missed that 😅. So this case still isn't clear.
Is this still an issue?
Needs tests
@sathvikbhagavan do you have code for this that demonstrates it? I don't think I've seen NNODE on CUDA at all.
Actually, I thought that since `PhysicsInformedNN` works with GPUs, this wouldn't be a problem, but apparently not 😅
MWE:
This uses the branch https://github.com/SciML/NeuralPDE.jl/tree/sb/try_gpu_nnode (I had to make some fixes).
```julia
using Random, NeuralPDE
using OrdinaryDiffEq
using Lux, OptimizationOptimisers
using LuxCUDA, ComponentArrays

rng = Random.default_rng()
Random.seed!(100)

const gpud = Lux.gpu_device()

# du/dt = cos(2πt), u(0) = 0 on t ∈ [0, 1]
linear = (u, p, t) -> cos(2pi * t)
tspan = (0.0, 1.0)
u0 = 0.0
prob = ODEProblem(linear, u0, tspan)

luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
# move the parameters to the GPU and promote them to Float64
ps = Lux.setup(rng, luxchain)[1] |> ComponentArray |> gpud .|> Float64
opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))

sol = solve(prob, NNODE(luxchain, opt, ps; strategy = GridTraining(0.01), device = gpud), verbose = true, maxiters = 200)
sol = solve(prob, NNODE(luxchain, opt, ps; strategy = StochasticTraining(100), device = gpud), verbose = true, maxiters = 200)
```
```
julia> sol = solve(prob, NNODE(luxchain, opt, ps; strategy = GridTraining(0.01), device = gpud), verbose = true, maxiters = 200)
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.
If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] errorscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:155
[3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:128
[4] assertscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:116
[5] getindex(A::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, I::Int64)
@ GPUArrays ~/.julia/packages/GPUArrays/Hd5Sk/src/host/indexing.jl:48
[6] scalar_getindex
@ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/indexing.jl:34 [inlined]
[7] _getindex
@ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/indexing.jl:17 [inlined]
[8] getindex
@ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/indexing.jl:15 [inlined]
[9] _generic_matmatmul!(C::Matrix{…}, tA::Char, tB::Char, A::CuArray{…}, B::StepRangeLen{…}, _add::LinearAlgebra.MulAddMul{…})
@ LinearAlgebra ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:814
[10] generic_matmatmul!(C::Matrix{…}, tA::Char, tB::Char, A::CuArray{…}, B::StepRangeLen{…}, _add::LinearAlgebra.MulAddMul{…})
@ LinearAlgebra ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:783
[11] mul!
@ ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:263 [inlined]
[12] mul!
@ ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
[13] *
@ ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:106 [inlined]
[14] rrule
@ ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/arraymath.jl:40 [inlined]
[15] rrule
@ ~/.julia/packages/ChainRulesCore/zgT0R/src/rules.jl:134 [inlined]
[16] chain_rrule
@ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:223 [inlined]
[17] macro expansion
@ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0 [inlined]
[18] _pullback(::Zygote.Context{false}, ::typeof(*), ::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, ::LinearAlgebra.Adjoint{Float64, StepRangeLen{…}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:81
[19] Dense
@ ~/.julia/packages/Lux/bejlA/src/layers/basic.jl:230 [inlined]
[20] apply
@ ~/.julia/packages/LuxCore/8lRV2/src/LuxCore.jl:180 [inlined]
[21] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Dense{…}, ::LinearAlgebra.Adjoint{…}, ::ComponentVector{…}, ::@NamedTuple{})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[22] macro expansion
@ ~/.julia/packages/Lux/bejlA/src/layers/containers.jl:0 [inlined]
[23] applychain
@ ~/.julia/packages/Lux/bejlA/src/layers/containers.jl:479 [inlined]
[24] _pullback(::Zygote.Context{…}, ::typeof(Lux.applychain), ::@NamedTuple{…}, ::LinearAlgebra.Adjoint{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[25] Chain
@ ~/.julia/packages/Lux/bejlA/src/layers/containers.jl:477 [inlined]
[26] _pullback(::Zygote.Context{…}, ::Chain{…}, ::LinearAlgebra.Adjoint{…}, ::ComponentVector{…}, ::@NamedTuple{…})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[27] ODEPhi
@ ~/NeuralPDE.jl/src/ode_solve.jl:132 [inlined]
[28] _pullback(::Zygote.Context{…}, ::NeuralPDE.ODEPhi{…}, ::StepRangeLen{…}, ::ComponentVector{…})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[29] inner_loss
@ ~/NeuralPDE.jl/src/ode_solve.jl:203 [inlined]
[30] _pullback(::Zygote.Context{…}, ::typeof(NeuralPDE.inner_loss), ::NeuralPDE.ODEPhi{…}, ::ODEFunction{…}, ::Bool, ::StepRangeLen{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters, ::Bool)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[31] loss
@ ~/NeuralPDE.jl/src/ode_solve.jl:254 [inlined]
[32] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#loss#190"{…}, ::ComponentVector{…}, ::NeuralPDE.ODEPhi{…})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[33] total_loss
@ ~/NeuralPDE.jl/src/ode_solve.jl:423 [inlined]
[34] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#total_loss#547"{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[35] _apply
@ ./boot.jl:838 [inlined]
[36] adjoint
@ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
[37] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[38] OptimizationFunction
@ ~/.julia/packages/SciMLBase/NjslX/src/scimlfunctions.jl:3649 [inlined]
[39] _pullback(::Zygote.Context{…}, ::OptimizationFunction{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[40] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[41] adjoint
@ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
[42] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[43] #37
@ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:90 [inlined]
[44] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#37#55"{…}, args::ComponentVector{…})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[45] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[46] adjoint
@ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
[47] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[48] #39
@ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93 [inlined]
[49] _pullback(ctx::Zygote.Context{false}, f::OptimizationZygoteExt.var"#39#57"{Tuple{}, OptimizationZygoteExt.var"#37#55"{…}}, args::ComponentVector{Float64, CuArray{…}, Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[50] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{…}}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:90
[51] pullback
@ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:88 [inlined]
[52] gradient(f::Function, args::ComponentVector{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{…}}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:147
[53] (::OptimizationZygoteExt.var"#38#56"{OptimizationZygoteExt.var"#37#55"{…}})(::ComponentVector{Float64, CuArray{…}, Tuple{…}}, ::ComponentVector{Float64, CuArray{…}, Tuple{…}})
@ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93
[54] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
[55] macro expansion
@ ~/.julia/packages/Optimization/5DEdF/src/utils.jl:32 [inlined]
[56] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
[57] solve!(cache::OptimizationCache{…})
@ SciMLBase ~/.julia/packages/SciMLBase/NjslX/src/solve.jl:180
[58] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
@ SciMLBase ~/.julia/packages/SciMLBase/NjslX/src/solve.jl:96
[59] __solve(::ODEProblem{…}, ::NNODE{…}; dt::Nothing, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Nothing, maxiters::Int64, tstops::Nothing)
@ NeuralPDE ~/NeuralPDE.jl/src/ode_solve.jl:464
[60] __solve
@ ~/NeuralPDE.jl/src/ode_solve.jl:344 [inlined]
[61] solve_call(_prob::ODEProblem{…}, args::NNODE{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:612
[62] solve_call
@ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:569 [inlined]
[63] #solve_up#53
@ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1080 [inlined]
[64] solve_up
@ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1066 [inlined]
[65] #solve#51
@ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1003 [inlined]
[66] top-level scope
@ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
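Frame [9] is the tell for the `GridTraining` failure: `_generic_matmatmul!` is multiplying a `CuArray` (the first `Dense` layer's weights) by the adjoint of a `StepRangeLen`, i.e. the CPU time grid that `GridTraining(0.01)` produced, so the product falls back to the generic CPU matmul and scalar-indexes the GPU array. A minimal sketch of that product, assuming only CUDA.jl:

```julia
using CUDA, LinearAlgebra

W = cu(rand(5, 1))     # GPU weights, shaped like Lux.Dense(1, 5)
ts = 0.0:0.01:1.0      # CPU StepRangeLen, as produced by GridTraining(0.01)

# No GPU method exists for CuArray * Adjoint{_, StepRangeLen}, so this hits
# the generic fallback and scalar-indexes the CuArray:
# W * ts'              # ERROR: Scalar indexing is disallowed.

# Collecting the grid onto the device keeps the product on the GPU:
W * cu(collect(ts))'   # 5×101 CuArray
```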
```
julia> sol = solve(prob, NNODE(luxchain, opt, ps; strategy = StochasticTraining(100), device = gpud), verbose = true, maxiters = 200)
ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#broadcast_kernel#38")(::CUDA.CuKernelContext, ::CuDeviceMatrix{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument
Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2, CUDA.Mem.DeviceBuffer}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Matrix{Float64}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{CuDeviceMatrix{Float64, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
.args is of type Tuple{Base.Broadcast.Extruded{Matrix{Float64}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{CuDeviceMatrix{Float64, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}} which is not isbits.
.1 is of type Base.Broadcast.Extruded{Matrix{Float64}, Tuple{Bool, Bool}, Tuple{Int64, Int64}} which is not isbits.
.x is of type Matrix{Float64} which is not isbits.
Stacktrace:
[1] check_invocation(job::GPUCompiler.CompilerJob)
@ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/validation.jl:92
[2] macro expansion
@ ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:123 [inlined]
[3] macro expansion
@ ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
[4]
@ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:121
[5] codegen
@ ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:110 [inlined]
[6] compile(target::Symbol, job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, strip::Bool, validate::Bool, only_entry::Bool)
@ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:106
[7] compile
@ ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:98 [inlined]
[8] #1072
@ ~/.julia/packages/CUDA/htRwP/src/compiler/compilation.jl:247 [inlined]
[9] JuliaContext(f::CUDA.var"#1072#1075"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}})
@ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:47
[10] compile(job::GPUCompiler.CompilerJob)
@ CUDA ~/.julia/packages/CUDA/htRwP/src/compiler/compilation.jl:246
[11] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
@ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/execution.jl:125
[12] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
@ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/execution.jl:103
[13] macro expansion
@ ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:367 [inlined]
[14] macro expansion
@ ./lock.jl:267 [inlined]
[15] cufunction(f::GPUArrays.var"#broadcast_kernel#38", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{…}, Base.Broadcast.Broadcasted{…}, Int64}}; kwargs::@Kwargs{})
@ CUDA ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:362
[16] cufunction
@ ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:359 [inlined]
[17] macro expansion
@ ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:112 [inlined]
[18] #launch_heuristic#1122
@ ~/.julia/packages/CUDA/htRwP/src/gpuarrays.jl:17 [inlined]
[19] launch_heuristic
@ ~/.julia/packages/CUDA/htRwP/src/gpuarrays.jl:15 [inlined]
[20] _copyto!
@ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/broadcast.jl:56 [inlined]
[21] copyto!
@ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/broadcast.jl:37 [inlined]
[22] copy
@ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/broadcast.jl:28 [inlined]
[23] materialize
@ ./broadcast.jl:903 [inlined]
[24] adjoint
@ ~/.julia/packages/Zygote/jxHJc/src/lib/broadcast.jl:92 [inlined]
[25] _pullback(__context__::Zygote.Context{false}, 590::typeof(Base.Broadcast.broadcasted), 591::typeof(*), x::Matrix{Float64}, y::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67
[26] ODEPhi
@ ~/NeuralPDE.jl/src/ode_solve.jl:135 [inlined]
[27] _pullback(::Zygote.Context{false}, ::NeuralPDE.ODEPhi{Chain{…}, Float64, Float64, @NamedTuple{…}}, ::Vector{Float64}, ::ComponentVector{Float64, CuArray{…}, Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[28] inner_loss
@ ~/NeuralPDE.jl/src/ode_solve.jl:203 [inlined]
[29] _pullback(::Zygote.Context{…}, ::typeof(NeuralPDE.inner_loss), ::NeuralPDE.ODEPhi{…}, ::ODEFunction{…}, ::Bool, ::Vector{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters, ::Bool)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[30] loss
@ ~/NeuralPDE.jl/src/ode_solve.jl:269 [inlined]
[31] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#loss#194"{…}, ::ComponentVector{…}, ::NeuralPDE.ODEPhi{…})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[32] total_loss
@ ~/NeuralPDE.jl/src/ode_solve.jl:423 [inlined]
[33] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#total_loss#547"{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[34] _apply
@ ./boot.jl:838 [inlined]
[35] adjoint
@ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
[36] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[37] OptimizationFunction
@ ~/.julia/packages/SciMLBase/NjslX/src/scimlfunctions.jl:3649 [inlined]
[38] _pullback(::Zygote.Context{…}, ::OptimizationFunction{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[39] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[40] adjoint
@ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
[41] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[42] #37
@ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:90 [inlined]
[43] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#37#55"{…}, args::ComponentVector{…})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[44] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:838
[45] adjoint
@ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
[46] _pullback
@ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
[47] #39
@ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93 [inlined]
[48] _pullback(ctx::Zygote.Context{false}, f::OptimizationZygoteExt.var"#39#57"{Tuple{}, OptimizationZygoteExt.var"#37#55"{…}}, args::ComponentVector{Float64, CuArray{…}, Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[49] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{…}}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:90
[50] pullback
@ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:88 [inlined]
[51] gradient(f::Function, args::ComponentVector{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{…}}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:147
[52] (::OptimizationZygoteExt.var"#38#56"{OptimizationZygoteExt.var"#37#55"{…}})(::ComponentVector{Float64, CuArray{…}, Tuple{…}}, ::ComponentVector{Float64, CuArray{…}, Tuple{…}})
@ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93
[53] macro expansion
@ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
[54] macro expansion
@ ~/.julia/packages/Optimization/5DEdF/src/utils.jl:32 [inlined]
[55] __solve(cache::OptimizationCache{…})
@ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
[56] solve!(cache::OptimizationCache{…})
@ SciMLBase ~/.julia/packages/SciMLBase/NjslX/src/solve.jl:180
[57] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
@ SciMLBase ~/.julia/packages/SciMLBase/NjslX/src/solve.jl:96
[58] __solve(::ODEProblem{…}, ::NNODE{…}; dt::Nothing, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Nothing, maxiters::Int64, tstops::Nothing)
@ NeuralPDE ~/NeuralPDE.jl/src/ode_solve.jl:464
[59] __solve
@ ~/NeuralPDE.jl/src/ode_solve.jl:344 [inlined]
[60] solve_call(_prob::ODEProblem{…}, args::NNODE{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:612
[61] solve_call
@ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:569 [inlined]
[62] #solve_up#53
@ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1080 [inlined]
[63] solve_up
@ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1066 [inlined]
[64] #solve#51
@ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1003 [inlined]
[65] top-level scope
@ REPL[20]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
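The `StochasticTraining` failure is the broadcast analogue: frame [25] shows `broadcasted(*, x::Matrix{Float64}, y::CuArray)`, so the sampled CPU time points end up as a `Matrix{Float64}` captured inside a GPU broadcast kernel, and the kernel argument fails the isbits check. A minimal sketch, again assuming only CUDA.jl:

```julia
using CUDA

t = rand(1, 100)       # CPU matrix, like the sampled training points
y = cu(rand(1, 100))   # GPU array, like the network output

# Broadcasting a CPU Matrix against a CuArray compiles a kernel that
# captures the Matrix, which is not isbits:
# t .* y               # ERROR: passing and using non-bitstype argument

# Moving the CPU side to the device first keeps the kernel arguments isbits:
cu(t) .* y             # 1×100 CuArray
```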
So, yeah, we still have to fix it.
When trying to set up CUDA with NNODE, it seems I hit some Zygote issues with the adjoints of `Array`. @DhairyaLGandhi could I get some help on this?