SciML / NeuralPDE.jl

Physics-Informed Neural Networks (PINN) Solvers of (Partial) Differential Equations for Scientific Machine Learning (SciML) accelerated simulation

GPU Compatibility Issue: Compilation Error with Complex-Valued Data in LuxCUDA Broadcasting Kernel #844

Open RomanSahakyan03 opened 2 months ago

RomanSahakyan03 commented 2 months ago

Bug Description


Solving an ODE with NeuralPDE's NNODE solver on a GPU fails with a GPU compilation error when the Lux network uses complex-valued (ComplexF64) parameters moved to the device with LuxCUDA.

Steps to Reproduce

Expected Behavior

The problem should solve without errors, using the GPU acceleration provided by LuxCUDA.

Observed Behavior

Compiling the GPU broadcast kernel (GPUArrays' broadcast_kernel) fails with "KernelError: passing and using non-bitstype argument". According to the stack trace, one of the broadcast arguments is a CPU Vector{ComplexF64} rather than a GPU array.

Code Snippet

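using Lux, LuxCUDA, ComponentArrays, Random
using NeuralPDE, Optimisers, OptimizationOptimisers, SciMLBase  # NNODE, StochasticTraining, Adam, allowscomplex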

# Select the GPU device and seed the RNG
const gpud = gpu_device()
rng = Random.default_rng()
Random.seed!(rng, 0)

# Define the neural network architecture (complex-valued weights in every layer)
inner = 16
chain = Chain(Dense(1, inner, tanh; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)),
              Dense(inner, inner, tanh; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)), 
              Dense(inner, inner, tanh; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)), 
              Dense(inner, 9; init_weight = (rng, a...) -> kaiming_normal(rng, ComplexF64, a...)))
ps = Lux.setup(rng, chain)[1]
ps = ps |> ComponentArray |> gpud .|> ComplexF64
# REPL output (random initial values omitted): ps is now a
# ComponentVector{ComplexF64, CuArray{ComplexF64, 1, CUDA.Mem.DeviceBuffer}, ...} with 729 entries:
# layer_1 = 1:32, layer_2 = 33:304, layer_3 = 305:576, layer_4 = 577:729 (biases initialized to 0.0 + 0.0im)
opt = Adam(0.01)
alg = NNODE(chain, opt, ps; strategy = StochasticTraining(300,30000))
SciMLBase.allowscomplex(::NNODE) = true

# Attempt to solve the problem (the definition of `problem`, an ODEProblem, is not included in the report)
sol = solve(problem, alg, verbose = true, maxiters = 1000, saveat = 0.001)
ERROR: GPU compilation of MethodInstance for (::GPUArrays.var"#broadcast_kernel#38")(::CUDA.CuKernelContext, ::CuDeviceVector{…}, ::Base.Broadcast.Broadcasted{…}, ::Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Tuple{Base.OneTo{Int64}}, typeof(+), Tuple{Base.Broadcast.Extruded{Vector{ComplexF64}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Float64, Float64}}, Base.Broadcast.Extruded{CuDeviceVector{ComplexF64, 1}, Tuple{Bool}, Tuple{Int64}}}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{Vector{ComplexF64}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.Mem.DeviceBuffer}, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Nothing, typeof(-), Tuple{Float64, Float64}}, Base.Broadcast.Extruded{CuDeviceVector{ComplexF64, 1}, Tuple{Bool}, Tuple{Int64}}}}} which is not isbits.
    .1 is of type Base.Broadcast.Extruded{Vector{ComplexF64}, Tuple{Bool}, Tuple{Int64}} which is not isbits.
      .x is of type Vector{ComplexF64} which is not isbits.
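The KernelError above follows from CUDA.jl's requirement that every kernel argument be an isbits value. At launch, CuArrays are converted to CuDeviceArrays (which are isbits, hence the CuDeviceVector in the trace), but a plain CPU Vector has no device counterpart, so a broadcast tree that still holds one cannot be compiled. The CPU-only Julia sketch below reproduces that isbits check without a GPU. The names u0, out, t and t0 are hypothetical stand-ins chosen to mirror the expression shape read off the trace, u0 .+ (t - t0) .* out; that this expression is NNODE's trial solution evaluated with a CPU-resident initial condition is an assumption, not something confirmed in this thread.

```julia
# CPU-only sketch: why a broadcast that captures a CPU Vector cannot be
# passed to a GPU kernel. No GPU or extra packages are needed.

# Element types such as ComplexF64 are plain bits.
println(isbitstype(ComplexF64))          # true

# A CPU array is a heap reference, not plain bits.
println(isbitstype(Vector{ComplexF64}))  # false

# Hypothetical stand-ins mirroring the trace: u0 .+ (t - t0) .* out
u0  = ComplexF64[1.0 + 0.0im, 0.0 + 1.0im]    # CPU vector (the Extruded{Vector{ComplexF64}} term)
out = ComplexF64[0.5 + 0.5im, 0.25 - 0.25im]  # on the GPU this would be a CuArray
t, t0 = 0.1, 0.0

# Build the lazy broadcast tree without materializing it.
scalar_part = Base.Broadcast.broadcasted(-, t, t0)  # Broadcasted{DefaultArrayStyle{0}}, as in the trace
bc = Base.Broadcast.broadcasted(+, u0,
        Base.Broadcast.broadcasted(*, scalar_part, out))

# The scalar sub-expression contains only bits types ...
println(isbits(scalar_part))  # true
# ... but the full tree holds a Vector reference, so it is not isbits.
println(isbits(bc))           # false

# Evaluating the same tree on the CPU works fine.
println(Base.Broadcast.materialize(bc))
```

Materializing the tree on the CPU succeeds, which is consistent with the failure being specific to GPU kernel argument passing rather than to complex arithmetic itself.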

Additional Information

IromainI commented 2 months ago

I am also trying to optimize a neural network on a GPU (LuxCUDA in Julia) and get the same GPU compilation error.

QuSimulations commented 2 months ago

I hit the same bug on the GPU (LuxCUDA in Julia) and get the same GPU compilation error.

RomanSahakyan03 commented 2 months ago

@sathvikbhagavan how can I assist you?

RomanSahakyan03 commented 2 months ago

@sathvikbhagavan ?

sathvikbhagavan commented 2 months ago

@RomanSahakyan03, apologies for the late reply. I will try to finish it up by this weekend.

RomanSahakyan03 commented 1 month ago

@sathvikbhagavan, no problem. Thanks for your efforts! If you need any help, I can assist.

RomanSahakyan03 commented 2 weeks ago

@sathvikbhagavan, any progress? Did you finish it?

sathvikbhagavan commented 1 week ago

Hi @RomanSahakyan03, I have a draft PR #866 to fix this, but I am currently running into some issues. Hopefully they will get resolved.