SciML / NeuralPDE.jl

Physics-Informed Neural Networks (PINN) Solvers of (Partial) Differential Equations for Scientific Machine Learning (SciML) accelerated simulation
https://docs.sciml.ai/NeuralPDE/stable/

CUDA with NNODE #556

Open ChrisRackauckas opened 2 years ago

ChrisRackauckas commented 2 years ago

When trying to set up CUDA with NNODE, it seems I hit some Zygote issues with the adjoints of Array. @DhairyaLGandhi, could I get some help on this?

using NeuralPDE, OrdinaryDiffEq, DiffEqFlux, Flux, OptimizationPolyalgorithms, CUDA
CUDA.allowscalar(false)

function f(u, p, t)
    [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
end

p = Float32[1.5, 1.0, 3.0, 1.0]
u0 = Float32[1.0, 1.0]
prob_oop = ODEProblem{false}(f, u0, (0f0, 3f0), p)
true_sol = solve(prob_oop, Tsit5())

prob_oop_cu = ODEProblem{false}(f, cu(u0), (0f0, 3f0), p)

N = 512
chain = FastChain(FastDense(1, N, relu), FastDense(N, N, relu), FastDense(N, N, relu), FastDense(N, length(u0)))
opt = ADAM(0.01)
θ = cu(DiffEqFlux.initial_params(chain))
alg = NeuralPDE.NNODE(chain, opt, θ; strategy = StochasticTraining(100))
sol = solve(prob_oop_cu, alg, verbose=true, maxiters=300)
julia> sol = solve(prob_oop_cu, alg, verbose=true, maxiters=300)
ERROR: MethodError: no method matching dot(::Int64, ::CuPtr{Float64}, ::Int64, ::Ptr{Float64}, ::Int64)
Closest candidates are:
  dot(::Integer, ::Union{Ptr{Float64}, AbstractArray{Float64}}, ::Integer, ::Union{Ptr{Float64}, AbstractArray{Float64}}, ::Integer) at C:\Users\accou\.julia\juliaup\julia-1.8.0-rc1+0~x64\share\julia\stdlib\v1.8\LinearAlgebra\src\blas.jl:338
  dot(::Integer, ::Union{Ptr{Float32}, AbstractArray{Float32}}, ::Integer, ::Union{Ptr{Float32}, AbstractArray{Float32}}, ::Integer) at C:\Users\accou\.julia\juliaup\julia-1.8.0-rc1+0~x64\share\julia\stdlib\v1.8\LinearAlgebra\src\blas.jl:338
Stacktrace:
  [1] dot(x::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, y::Matrix{Float64})
    @ LinearAlgebra.BLAS C:\Users\accou\.julia\juliaup\julia-1.8.0-rc1+0~x64\share\julia\stdlib\v1.8\LinearAlgebra\src\blas.jl:389
  [2] dot(x::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, y::Matrix{Float64})
    @ LinearAlgebra C:\Users\accou\.julia\juliaup\julia-1.8.0-rc1+0~x64\share\julia\stdlib\v1.8\LinearAlgebra\src\matmul.jl:14
  [3] (::ChainRules.var"#1494#1498"{Matrix{Float64}, CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, Float64})()
    @ ChainRules C:\Users\accou\.julia\packages\ChainRules\nu2G0\src\rulesets\Base\arraymath.jl:380
  [4] unthunk
    @ C:\Users\accou\.julia\packages\ChainRulesCore\16PWJ\src\tangent_types\thunks.jl:195 [inlined]
  [5] wrap_chainrules_output
    @ C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\chainrules.jl:104 [inlined]
  [6] map
    @ .\tuple.jl:223 [inlined]
  [7] wrap_chainrules_output
    @ C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\chainrules.jl:105 [inlined]
  [8] ZBack
    @ C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\chainrules.jl:205 [inlined]
  [9] (::Zygote.var"#3829#back#1024"{Zygote.ZBack{ChainRules.var"#slash_pullback_scalar#1495"{CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, Float64}}})(Δ::Matrix{Float64})
    @ Zygote C:\Users\accou\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
 [10] Pullback
    @ C:\Users\accou\.julia\dev\NeuralPDE\src\ode_solve.jl:163 [inlined]
 [11] (::typeof(∂(ode_dfdx)))(Δ::Matrix{Float64})
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [12] Pullback
    @ c:\Users\accou\.julia\dev\NeuralPDE\src\ode_solve.jl:191 [inlined]
 [13] (::typeof(∂(inner_loss)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [14] Pullback
    @ C:\Users\accou\.julia\dev\NeuralPDE\src\ode_solve.jl:217 [inlined]
 [15] (::typeof(∂(λ)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [16] #208
    @ C:\Users\accou\.julia\packages\Zygote\DkIUK\src\lib\lib.jl:207 [inlined]
 [17] #1750#back
    @ C:\Users\accou\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [18] Pullback
    @ C:\Users\accou\.julia\dev\SciMLBase\src\scimlfunctions.jl:2887 [inlined]
 [19] (::typeof(∂(λ)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [20] #208
    @ C:\Users\accou\.julia\packages\Zygote\DkIUK\src\lib\lib.jl:207 [inlined]
 [21] #1750#back
    @ C:\Users\accou\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [22] Pullback
    @ C:\Users\accou\.julia\dev\Optimization\src\function\zygote.jl:30 [inlined]
 [23] (::typeof(∂(λ)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [24] #208
    @ C:\Users\accou\.julia\packages\Zygote\DkIUK\src\lib\lib.jl:207 [inlined]
 [25] #1750#back
    @ C:\Users\accou\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
 [26] Pullback
    @ C:\Users\accou\.julia\dev\Optimization\src\function\zygote.jl:32 [inlined]
 [27] (::typeof(∂(λ)))(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [28] (::Zygote.var"#52#53"{typeof(∂(λ))})(Δ::Float64)
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface.jl:41
 [29] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\accou\.julia\packages\Zygote\DkIUK\src\compiler\interface.jl:76
 [30] (::Optimization.var"#117#127"{Optimization.var"#116#126"{OptimizationFunction{true, Optimization.AutoZygote, NeuralPDE.var"#loss#162"{StochasticTraining, NeuralPDE.ODEPhi{FastChain{Tuple{FastDense{typeof(relu), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}, FastDense{typeof(relu), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}, FastDense{typeof(relu), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}}}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Bool, Tuple{Float32, Float32}, Vector{Float32}, Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, SciMLBase.NullParameters}})(::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Optimization C:\Users\accou\.julia\dev\Optimization\src\function\zygote.jl:32
 [31] macro expansion
    @ C:\Users\accou\.julia\packages\OptimizationFlux\cpWyO\src\OptimizationFlux.jl:32 [inlined]
 [32] macro expansion
    @ C:\Users\accou\.julia\dev\Optimization\src\utils.jl:35 [inlined]
 [33] __solve(prob::OptimizationProblem{true, OptimizationFunction{true, Optimization.AutoZygote, NeuralPDE.var"#loss#162"{StochasticTraining, NeuralPDE.ODEPhi{FastChain{Tuple{FastDense{typeof(relu), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}, FastDense{typeof(relu), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}, FastDense{typeof(relu), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}}}, Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Bool, Tuple{Float32, Float32}, Vector{Float32}, Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, SciMLBase.NullParameters, Nothing, Nothing, Nothing, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, opt::ADAM, data::Base.Iterators.Cycle{Tuple{Optimization.NullData}}; maxiters::Int64, callback::Function, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OptimizationFlux C:\Users\accou\.julia\packages\OptimizationFlux\cpWyO\src\OptimizationFlux.jl:30
 [34] #solve#494
    @ C:\Users\accou\.julia\dev\SciMLBase\src\solve.jl:71 [inlined]
 [35] __solve(::ODEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ::NNODE{FastChain{Tuple{FastDense{typeof(relu), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}, FastDense{typeof(relu), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}, FastDense{typeof(relu), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}}}, ADAM, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, StochasticTraining}; dt::Nothing, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Nothing, maxiters::Int64)
    @ NeuralPDE c:\Users\accou\.julia\dev\NeuralPDE\src\ode_solve.jl:351
 [36] #solve_call#28
    @ C:\Users\accou\.julia\dev\DiffEqBase\src\solve.jl:429 [inlined]
 [37] solve_up(prob::ODEProblem{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}, false, Vector{Float32}, ODEFunction{false, typeof(f), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, sensealg::Nothing, u0::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, p::Vector{Float32}, args::NNODE{FastChain{Tuple{FastDense{typeof(relu), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}, FastDense{typeof(relu), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}, FastDense{typeof(relu), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}, FastDense{typeof(identity), DiffEqFlux.var"#initial_params#107"{Vector{Float32}}, Nothing}}}, ADAM, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Nothing, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, StochasticTraining}; kwargs::Base.Pairs{Symbol, Integer, Tuple{Symbol, Symbol}, NamedTuple{(:verbose, :maxiters), Tuple{Bool, Int64}}})
    @ DiffEqBase C:\Users\accou\.julia\dev\DiffEqBase\src\solve.jl:767
 [38] #solve#33
    @ C:\Users\accou\.julia\dev\DiffEqBase\src\solve.jl:752 [inlined]
 [39] top-level scope
    @ REPL[2]:1
using Plots
plot(sol)
plot!(true_sol)
DhairyaLGandhi commented 2 years ago

Seems like f outputs Arrays, and that would be the function Zygote sees, as I understand it. Its output would need to be on the GPU. I will try to get an MWE rolling on my end, but it seems like the output is on the CPU; stack frame [9] suggests that. I also think p should be on the GPU for consistency.
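
For what it's worth, a sketch of a fully GPU-resident setup along those lines might look like the following. The names f_gpu/p_gpu/u0_gpu are mine, and the RHS uses 1-element range slices so that p can live on the device without triggering scalar indexing; this is an illustration of the suggestion, not a verified fix for NNODE.

using CUDA
CUDA.allowscalar(false)

# Out-of-place RHS that keeps everything on the device: range slices like
# u[1:1] avoid scalar getindex on CuArrays, and vcat of CuArrays returns a
# CuArray, so the output never lands on the CPU.
function f_gpu(u, p, t)
    a, b, c, d = p[1:1], p[2:2], p[3:3], p[4:4]
    vcat(a .* u[1:1] .- b .* u[1:1] .* u[2:2],
         -c .* u[2:2] .+ d .* u[1:1] .* u[2:2])
end

p_gpu = cu(Float32[1.5, 1.0, 3.0, 1.0])  # p on the GPU for consistency
u0_gpu = cu(Float32[1.0, 1.0])
f_gpu(u0_gpu, p_gpu, 0f0)                # returns a 2-element CuArray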

ChrisRackauckas commented 2 years ago

That accumulation happens on the CPU: https://github.com/SciML/NeuralPDE.jl/blob/master/src/ode_solve.jl#L197-L201
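
The failure mode behind that link can be reproduced in isolation. This standalone snippet (my own, not NeuralPDE source) shows why a CPU-side set of training points poisons the GPU path: a CuArray-times-Matrix product falls back to the generic CPU matmul kernel, which scalar-indexes the CuArray.

using CUDA
CUDA.allowscalar(false)

W = cu(rand(Float32, 5, 1))    # GPU weights, standing in for the NN parameters
ts = collect(0f0:0.1f0:1f0)    # CPU sample of training points
t_row = reshape(ts, 1, :)      # 1×N input row for the chain

# W * t_row                    # CPU/GPU mix: generic matmul fallback, errors
y = W * cu(t_row)              # all-GPU: dispatches to CUBLAS, returns a CuArray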

ChrisRackauckas commented 2 years ago

It looks like this is another instantiation of https://github.com/SciML/NeuralPDE.jl/issues/533, because I did chain = FastChain(FastDense(1, N, relu), FastDense(N, N, relu), FastDense(N, N, relu), FastDense(N, length(u0))) without piping it through |> gpu, and the types in the NN matter even though they aren't used.

ChrisRackauckas commented 2 years ago

Actually nope, this case was using FastChain; missed that 😅. So this case still isn't clear.

sathvikbhagavan commented 7 months ago

Is this still an issue?

ChrisRackauckas commented 7 months ago

Needs tests

ChrisRackauckas commented 7 months ago

@sathvikbhagavan do you have code for this which demonstrates it? I don't think I've seen NNODE on CUDA at all.

sathvikbhagavan commented 7 months ago

Actually, I thought that since PhysicsInformedNN works with GPUs, this wouldn't be a problem, but apparently not 😅

MWE:

This uses the branch https://github.com/SciML/NeuralPDE.jl/tree/sb/try_gpu_nnode (I had to do some fixes).

using Random, NeuralPDE
using OrdinaryDiffEq
using Lux, OptimizationOptimisers
using LuxCUDA, ComponentArrays

rng = Random.default_rng()
Random.seed!(100)
const gpud = Lux.gpu_device()

linear = (u, p, t) -> cos(2pi * t)
tspan = (0.0, 1.0)
u0 = 0.0
prob = ODEProblem(linear, u0, tspan)
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
ps = Lux.setup(rng, luxchain)[1] |> ComponentArray |> gpud .|> Float64
opt = OptimizationOptimisers.Adam(0.1, (0.9, 0.95))

sol = solve(prob, NNODE(luxchain, opt, ps; strategy = GridTraining(0.01), device = gpud), verbose = true, maxiters = 200)

sol = solve(prob, NNODE(luxchain, opt, ps; strategy = StochasticTraining(100), device = gpud), verbose = true, maxiters = 200)
julia> sol = solve(prob, NNODE(luxchain, opt, ps; strategy = GridTraining(0.01), device = gpud), verbose = true, maxiters = 200)
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] errorscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:155
  [3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:128
  [4] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:116
  [5] getindex(A::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, I::Int64)
    @ GPUArrays ~/.julia/packages/GPUArrays/Hd5Sk/src/host/indexing.jl:48
  [6] scalar_getindex
    @ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/indexing.jl:34 [inlined]
  [7] _getindex
    @ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/indexing.jl:17 [inlined]
  [8] getindex
    @ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/indexing.jl:15 [inlined]
  [9] _generic_matmatmul!(C::Matrix{…}, tA::Char, tB::Char, A::CuArray{…}, B::StepRangeLen{…}, _add::LinearAlgebra.MulAddMul{…})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:814
 [10] generic_matmatmul!(C::Matrix{…}, tA::Char, tB::Char, A::CuArray{…}, B::StepRangeLen{…}, _add::LinearAlgebra.MulAddMul{…})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:783
 [11] mul!
    @ ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:263 [inlined]
 [12] mul!
    @ ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:237 [inlined]
 [13] *
    @ ~/.julia/juliaup/julia-1.10.2+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:106 [inlined]
 [14] rrule
    @ ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/arraymath.jl:40 [inlined]
 [15] rrule
    @ ~/.julia/packages/ChainRulesCore/zgT0R/src/rules.jl:134 [inlined]
 [16] chain_rrule
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:223 [inlined]
 [17] macro expansion
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0 [inlined]
 [18] _pullback(::Zygote.Context{false}, ::typeof(*), ::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer}, ::LinearAlgebra.Adjoint{Float64, StepRangeLen{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:81
 [19] Dense
    @ ~/.julia/packages/Lux/bejlA/src/layers/basic.jl:230 [inlined]
 [20] apply
    @ ~/.julia/packages/LuxCore/8lRV2/src/LuxCore.jl:180 [inlined]
 [21] _pullback(::Zygote.Context{…}, ::typeof(LuxCore.apply), ::Dense{…}, ::LinearAlgebra.Adjoint{…}, ::ComponentVector{…}, ::@NamedTuple{})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [22] macro expansion
    @ ~/.julia/packages/Lux/bejlA/src/layers/containers.jl:0 [inlined]
 [23] applychain
    @ ~/.julia/packages/Lux/bejlA/src/layers/containers.jl:479 [inlined]
 [24] _pullback(::Zygote.Context{…}, ::typeof(Lux.applychain), ::@NamedTuple{…}, ::LinearAlgebra.Adjoint{…}, ::ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [25] Chain
    @ ~/.julia/packages/Lux/bejlA/src/layers/containers.jl:477 [inlined]
 [26] _pullback(::Zygote.Context{…}, ::Chain{…}, ::LinearAlgebra.Adjoint{…}, ::ComponentVector{…}, ::@NamedTuple{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [27] ODEPhi
    @ ~/NeuralPDE.jl/src/ode_solve.jl:132 [inlined]
 [28] _pullback(::Zygote.Context{…}, ::NeuralPDE.ODEPhi{…}, ::StepRangeLen{…}, ::ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [29] inner_loss
    @ ~/NeuralPDE.jl/src/ode_solve.jl:203 [inlined]
 [30] _pullback(::Zygote.Context{…}, ::typeof(NeuralPDE.inner_loss), ::NeuralPDE.ODEPhi{…}, ::ODEFunction{…}, ::Bool, ::StepRangeLen{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters, ::Bool)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [31] loss
    @ ~/NeuralPDE.jl/src/ode_solve.jl:254 [inlined]
 [32] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#loss#190"{…}, ::ComponentVector{…}, ::NeuralPDE.ODEPhi{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [33] total_loss
    @ ~/NeuralPDE.jl/src/ode_solve.jl:423 [inlined]
 [34] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#total_loss#547"{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [35] _apply
    @ ./boot.jl:838 [inlined]
 [36] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [37] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [38] OptimizationFunction
    @ ~/.julia/packages/SciMLBase/NjslX/src/scimlfunctions.jl:3649 [inlined]
 [39] _pullback(::Zygote.Context{…}, ::OptimizationFunction{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [40] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [41] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [42] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [43] #37
    @ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:90 [inlined]
 [44] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#37#55"{…}, args::ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [45] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [46] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [47] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [48] #39
    @ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93 [inlined]
 [49] _pullback(ctx::Zygote.Context{false}, f::OptimizationZygoteExt.var"#39#57"{Tuple{}, OptimizationZygoteExt.var"#37#55"{…}}, args::ComponentVector{Float64, CuArray{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [50] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:90
 [51] pullback
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:88 [inlined]
 [52] gradient(f::Function, args::ComponentVector{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:147
 [53] (::OptimizationZygoteExt.var"#38#56"{OptimizationZygoteExt.var"#37#55"{…}})(::ComponentVector{Float64, CuArray{…}, Tuple{…}}, ::ComponentVector{Float64, CuArray{…}, Tuple{…}})
    @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93
 [54] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [55] macro expansion
    @ ~/.julia/packages/Optimization/5DEdF/src/utils.jl:32 [inlined]
 [56] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [57] solve!(cache::OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/NjslX/src/solve.jl:180
 [58] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/NjslX/src/solve.jl:96
 [59] __solve(::ODEProblem{…}, ::NNODE{…}; dt::Nothing, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Nothing, maxiters::Int64, tstops::Nothing)
    @ NeuralPDE ~/NeuralPDE.jl/src/ode_solve.jl:464
 [60] __solve
    @ ~/NeuralPDE.jl/src/ode_solve.jl:344 [inlined]
 [61] solve_call(_prob::ODEProblem{…}, args::NNODE{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:612
 [62] solve_call
    @ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:569 [inlined]
 [63] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1080 [inlined]
 [64] solve_up
    @ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1066 [inlined]
 [65] #solve#51
    @ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1003 [inlined]
 [66] top-level scope
    @ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.
julia> sol = solve(prob, NNODE(luxchain, opt, ps; strategy = StochasticTraining(100), device = gpud), verbose = true, maxiters = 200)
ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#broadcast_kernel#38")(::CUDA.CuKernelContext, ::CuDeviceMatrix{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2, CUDA.Mem.DeviceBuffer}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Matrix{Float64}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{CuDeviceMatrix{Float64, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{Matrix{Float64}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Extruded{CuDeviceMatrix{Float64, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}} which is not isbits.
    .1 is of type Base.Broadcast.Extruded{Matrix{Float64}, Tuple{Bool, Bool}, Tuple{Int64, Int64}} which is not isbits.
      .x is of type Matrix{Float64} which is not isbits.

Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/validation.jl:92
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:123 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/TimerOutputs/RsWnF/src/TimerOutput.jl:253 [inlined]
  [4] 
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:121
  [5] codegen
    @ ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:110 [inlined]
  [6] compile(target::Symbol, job::GPUCompiler.CompilerJob; libraries::Bool, toplevel::Bool, optimize::Bool, cleanup::Bool, strip::Bool, validate::Bool, only_entry::Bool)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:106
  [7] compile
    @ ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:98 [inlined]
  [8] #1072
    @ ~/.julia/packages/CUDA/htRwP/src/compiler/compilation.jl:247 [inlined]
  [9] JuliaContext(f::CUDA.var"#1072#1075"{GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams}})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:47
 [10] compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~/.julia/packages/CUDA/htRwP/src/compiler/compilation.jl:246
 [11] actual_compilation(cache::Dict{…}, src::Core.MethodInstance, world::UInt64, cfg::GPUCompiler.CompilerConfig{…}, compiler::typeof(CUDA.compile), linker::typeof(CUDA.link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/execution.jl:125
 [12] cached_compilation(cache::Dict{…}, src::Core.MethodInstance, cfg::GPUCompiler.CompilerConfig{…}, compiler::Function, linker::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/execution.jl:103
 [13] macro expansion
    @ ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:367 [inlined]
 [14] macro expansion
    @ ./lock.jl:267 [inlined]
 [15] cufunction(f::GPUArrays.var"#broadcast_kernel#38", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{…}, Base.Broadcast.Broadcasted{…}, Int64}}; kwargs::@Kwargs{})
    @ CUDA ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:362
 [16] cufunction
    @ ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:359 [inlined]
 [17] macro expansion
    @ ~/.julia/packages/CUDA/htRwP/src/compiler/execution.jl:112 [inlined]
 [18] #launch_heuristic#1122
    @ ~/.julia/packages/CUDA/htRwP/src/gpuarrays.jl:17 [inlined]
 [19] launch_heuristic
    @ ~/.julia/packages/CUDA/htRwP/src/gpuarrays.jl:15 [inlined]
 [20] _copyto!
    @ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/broadcast.jl:56 [inlined]
 [21] copyto!
    @ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/broadcast.jl:37 [inlined]
 [22] copy
    @ ~/.julia/packages/GPUArrays/Hd5Sk/src/host/broadcast.jl:28 [inlined]
 [23] materialize
    @ ./broadcast.jl:903 [inlined]
 [24] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/broadcast.jl:92 [inlined]
 [25] _pullback(__context__::Zygote.Context{false}, 590::typeof(Base.Broadcast.broadcasted), 591::typeof(*), x::Matrix{Float64}, y::CuArray{Float64, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67
 [26] ODEPhi
    @ ~/NeuralPDE.jl/src/ode_solve.jl:135 [inlined]
 [27] _pullback(::Zygote.Context{false}, ::NeuralPDE.ODEPhi{Chain{…}, Float64, Float64, @NamedTuple{…}}, ::Vector{Float64}, ::ComponentVector{Float64, CuArray{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [28] inner_loss
    @ ~/NeuralPDE.jl/src/ode_solve.jl:203 [inlined]
 [29] _pullback(::Zygote.Context{…}, ::typeof(NeuralPDE.inner_loss), ::NeuralPDE.ODEPhi{…}, ::ODEFunction{…}, ::Bool, ::Vector{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters, ::Bool)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [30] loss
    @ ~/NeuralPDE.jl/src/ode_solve.jl:269 [inlined]
 [31] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#loss#194"{…}, ::ComponentVector{…}, ::NeuralPDE.ODEPhi{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [32] total_loss
    @ ~/NeuralPDE.jl/src/ode_solve.jl:423 [inlined]
 [33] _pullback(::Zygote.Context{…}, ::NeuralPDE.var"#total_loss#547"{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [34] _apply
    @ ./boot.jl:838 [inlined]
 [35] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [36] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [37] OptimizationFunction
    @ ~/.julia/packages/SciMLBase/NjslX/src/scimlfunctions.jl:3649 [inlined]
 [38] _pullback(::Zygote.Context{…}, ::OptimizationFunction{…}, ::ComponentVector{…}, ::SciMLBase.NullParameters)
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [39] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [40] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [41] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [42] #37
    @ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:90 [inlined]
 [43] _pullback(ctx::Zygote.Context{…}, f::OptimizationZygoteExt.var"#37#55"{…}, args::ComponentVector{…})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [44] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [45] adjoint
    @ ~/.julia/packages/Zygote/jxHJc/src/lib/lib.jl:203 [inlined]
 [46] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [47] #39
    @ ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93 [inlined]
 [48] _pullback(ctx::Zygote.Context{false}, f::OptimizationZygoteExt.var"#39#57"{Tuple{}, OptimizationZygoteExt.var"#37#55"{…}}, args::ComponentVector{Float64, CuArray{…}, Tuple{…}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
 [49] pullback(f::Function, cx::Zygote.Context{false}, args::ComponentVector{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:90
 [50] pullback
    @ ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:88 [inlined]
 [51] gradient(f::Function, args::ComponentVector{Float64, CuArray{Float64, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{…}}})
    @ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:147
 [52] (::OptimizationZygoteExt.var"#38#56"{OptimizationZygoteExt.var"#37#55"{…}})(::ComponentVector{Float64, CuArray{…}, Tuple{…}}, ::ComponentVector{Float64, CuArray{…}, Tuple{…}})
    @ OptimizationZygoteExt ~/.julia/packages/OptimizationBase/rRpJs/ext/OptimizationZygoteExt.jl:93
 [53] macro expansion
    @ ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:68 [inlined]
 [54] macro expansion
    @ ~/.julia/packages/Optimization/5DEdF/src/utils.jl:32 [inlined]
 [55] __solve(cache::OptimizationCache{…})
    @ OptimizationOptimisers ~/.julia/packages/OptimizationOptimisers/AOkbT/src/OptimizationOptimisers.jl:66
 [56] solve!(cache::OptimizationCache{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/NjslX/src/solve.jl:180
 [57] solve(::OptimizationProblem{…}, ::Adam; kwargs::@Kwargs{…})
    @ SciMLBase ~/.julia/packages/SciMLBase/NjslX/src/solve.jl:96
 [58] __solve(::ODEProblem{…}, ::NNODE{…}; dt::Nothing, timeseries_errors::Bool, save_everystep::Bool, adaptive::Bool, abstol::Float32, reltol::Float32, verbose::Bool, saveat::Nothing, maxiters::Int64, tstops::Nothing)
    @ NeuralPDE ~/NeuralPDE.jl/src/ode_solve.jl:464
 [59] __solve
    @ ~/NeuralPDE.jl/src/ode_solve.jl:344 [inlined]
 [60] solve_call(_prob::ODEProblem{…}, args::NNODE{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:612
 [61] solve_call
    @ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:569 [inlined]
 [62] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1080 [inlined]
 [63] solve_up
    @ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1066 [inlined]
 [64] #solve#51
    @ ~/.julia/packages/DiffEqBase/O8cUq/src/solve.jl:1003 [inlined]
 [65] top-level scope
    @ REPL[20]:1
Some type information was truncated. Use `show(err)` to see complete types.

So, yeah, we still have to fix it.
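
Both traces point at the same mismatch: the strategy's time points (a StepRangeLen under GridTraining, a Vector{Float64} under StochasticTraining) stay on the host while ps lives on the device, so the first weight-times-input multiply inside the chain hits a generic CPU fallback or a non-isbits broadcast argument. One plausible direction, sketched under the assumption that NNODE knows the target device (the device keyword on the branch suggests as much; to_device is a hypothetical helper, not NeuralPDE API), is to adapt the points before they reach the network:

using Lux, LuxCUDA

const gpud = Lux.gpu_device()

# Hypothetical helper: reshape the sampled time points into the 1×N input
# matrix the chain expects and move them onto the parameters' device.
to_device(dev, ts) = dev(reshape(collect(ts), 1, :))

ts = 0.0:0.01:1.0            # what GridTraining(0.01) samples on tspan = (0.0, 1.0)
t_dev = to_device(gpud, ts)  # 1×101 CuArray; luxchain(t_dev, ps, st) stays on-GPU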