The workaround, of course, is to just convert directly to a CuArray:
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots, CUDA, DiffEqSensitivity
CUDA.allowscalar(false) # Makes sure no slow scalar operations are occurring
# Generate Data
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Make the data into a GPU-based array if the user has a GPU
ode_data = CuArray(solve(prob_trueode, Tsit5(), saveat = tsteps))
dudt2 = FastChain((x, p) -> x.^3,
                  FastDense(2, 50, tanh),
                  FastDense(50, 2))
u0 = Float32[2.0; 0.0] |> gpu
p = initial_params(dudt2) |> gpu
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
function predict_neuralode(p)
    CuArray(prob_neuralode(u0, p))  # convert the solution directly to a CuArray
end
function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end
# Callback function to observe training
list_plots = []
iter = 0
callback = function (p, l, pred; doplot = false)
    global list_plots, iter
    if iter == 0
        list_plots = []
    end
    iter += 1
    display(l)
    # plot current prediction against data
    plt = scatter(tsteps, Array(ode_data[1, :]), label = "data")
    scatter!(plt, tsteps, Array(pred[1, :]), label = "prediction")
    push!(list_plots, plt)
    if doplot
        display(plot(plt))
    end
    return false
end
result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, p,
                                          ADAM(0.05), cb = callback,
                                          maxiters = 300)
@DhairyaLGandhi @mcabbott do you know why gpu is now going through a bunch of functor stuff and is not AbstractArray-generic? The MWE is:
using RecursiveArrayTools, CUDA, Flux  # Flux for `gpu`
x = VectorOfArray([rand(Float32, 4), rand(Float32, 4)])
CuArray(x)  # works
gpu(x)      # fails
and it used to just call CuArray?
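For reference, a minimal sketch of where the two code paths diverge, based on the diagnosis that emerges further down in this thread (the eltype-constructor detail is taken from there, not verified independently):
using RecursiveArrayTools, CUDA
x = VectorOfArray([rand(Float32, 4), rand(Float32, 4)])
CuArray(x)            # plain constructor: gathers via Array(x), which exists
CuArray{Float32}(x)   # eltype-specifying constructor: needs Array{Float32}(x), which is missing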
What's the Zygote version? We haven't done a Functors release, but Zygote was released recently. I'm not sure if that would explain it entirely, but surely this is a bug. Is it a recent RecursiveArrayTools release?
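For reference, the versions in question can be checked with the standard Pkg API:
using Pkg
Pkg.status("Zygote")
Pkg.status("RecursiveArrayTools")
Pkg.status("CUDA")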
Found the issue: it stems from the new CUDA v3.4 release series.
With CUDA v3.4.1:
julia> x = VectorOfArray([rand(Float32,4),rand(Float32,4)]);
julia> Flux.CUDA.cu(x)
ERROR: MethodError: Cannot `convert` an object of type Vector{Float32} to an object of type Float32
Closest candidates are:
convert(::Type{T}, ::ColorTypes.Gray24) where T<:Real at /home/dhairyalgandhi/.julia/packages/ColorTypes/6m8P7/src/conversions.jl:114
convert(::Type{T}, ::ColorTypes.Gray) where T<:Real at /home/dhairyalgandhi/.julia/packages/ColorTypes/6m8P7/src/conversions.jl:113
convert(::Type{T}, ::Static.StaticFloat64{N}) where {N, T<:AbstractFloat} at /home/dhairyalgandhi/.julia/packages/Static/lCOFN/src/float.jl:26
With CUDA v3.3.6:
julia> x = VectorOfArray([rand(Float32,4),rand(Float32,4)]);
julia> Flux.CUDA.cu(x)
4×2 CUDA.CuArray{Float32, 2}:
0.840366 0.391562
0.682369 0.146678
0.15754 0.320327
0.0135089 0.0619746
@maleadt have we changed something in CUDA.cu?
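For anyone following along, switching between the two CUDA.jl versions compared above can be done with the standard Pkg API:
using Pkg
Pkg.add(name = "CUDA", version = "3.3.6")   # last release series before the regression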
Thanks! Reverting to CUDA v3.3.6 allows me to run that example on the GPU. However, there are still issues with other functions, e.g. normalizing flows on the GPU:
using DiffEqFlux, DifferentialEquations, GalacticOptim, Distributions
nn = Chain(
    Dense(1, 3, tanh),
    Dense(3, 1, tanh),
) |> gpu
tspan = (0.0f0, 10.0f0)
ffjord_mdl = FFJORD(nn, tspan, Tsit5())
data_dist = Normal(6.0f0, 0.7f0)
train_data = gpu(rand(data_dist, 1, 100))
function loss(θ)
    logpx, λ₁, λ₂ = ffjord_mdl(train_data, θ)
    -mean(logpx)
end
loss(ffjord_mdl.p)
gives the following error:
ERROR: GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(conj), Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(-), Tuple{Int64, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(Base.literal_pow), Tuple{CUDA.CuRefValue{typeof(^)}, Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, CUDA.CuRefValue{Val{2}}}}}}}}}}, Int64) failed
KernelError: passing and using non-bitstype argument
Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(conj), Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(-), Tuple{Int64, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(Base.literal_pow), Tuple{CUDA.CuRefValue{typeof(^)}, Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, CUDA.CuRefValue{Val{2}}}}}}}}}}, which is not isbits:
.args is of type Tuple{Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(conj), Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(-), Tuple{Int64, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(Base.literal_pow), Tuple{CUDA.CuRefValue{typeof(^)}, Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, CUDA.CuRefValue{Val{2}}}}}}}}} which is not isbits.
.1 is of type Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}} which is not isbits.
.x is of type Matrix{Float32} which is not isbits.
Stacktrace:
[1] check_invocation(job::GPUCompiler.CompilerJob)
@ GPUCompiler /data/packages/GPUCompiler/fG3xK/src/validation.jl:66
[2] macro expansion
@ /data/packages/GPUCompiler/fG3xK/src/driver.jl:318 [inlined]
[3] macro expansion
@ /data/packages/TimerOutputs/ZQ0rt/src/TimerOutput.jl:236 [inlined]
[4] macro expansion
@ /data/packages/GPUCompiler/fG3xK/src/driver.jl:317 [inlined]
[5] emit_asm(job::GPUCompiler.CompilerJob, ir::LLVM.Module; strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
@ GPUCompiler /data/packages/GPUCompiler/fG3xK/src/utils.jl:62
[6] cufunction_compile(job::GPUCompiler.CompilerJob)
@ CUDA /data/packages/CUDA/DL5Zo/src/compiler/execution.jl:317
[7] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
@ GPUCompiler /data/packages/GPUCompiler/fG3xK/src/cache.jl:89
[8] cufunction(f::GPUArrays.var"#broadcast_kernel#17", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(conj), Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(-), Tuple{Int64, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(Base.literal_pow), Tuple{CUDA.CuRefValue{typeof(^)}, Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, CUDA.CuRefValue{Val{2}}}}}}}}}}, Int64}}; name::Nothing, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ CUDA /data/packages/CUDA/DL5Zo/src/compiler/execution.jl:288
[9] cufunction(f::GPUArrays.var"#broadcast_kernel#17", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Float32, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, typeof(*), Tuple{Base.Broadcast.Extruded{Matrix{Float32}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(conj), Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(-), Tuple{Int64, Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, typeof(Base.literal_pow), Tuple{CUDA.CuRefValue{typeof(^)}, Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, CUDA.CuRefValue{Val{2}}}}}}}}}}, Int64}})
@ CUDA /data/packages/CUDA/DL5Zo/src/compiler/execution.jl:282
[10] macro expansion
@ /data/packages/CUDA/DL5Zo/src/compiler/execution.jl:102 [inlined]
[11] #launch_heuristic#241
@ /data/packages/CUDA/DL5Zo/src/gpuarrays.jl:17 [inlined]
[12] copyto!
@ /data/packages/GPUArrays/UBzTm/src/host/broadcast.jl:65 [inlined]
[13] copyto!
@ ./broadcast.jl:936 [inlined]
[14] copy
@ /data/packages/GPUArrays/UBzTm/src/host/broadcast.jl:47 [inlined]
[15] materialize
@ ./broadcast.jl:883 [inlined]
[16] (::Zygote.var"#1076#1077"{CuArray{Float32, 2}})(ȳ::Matrix{Float32})
@ Zygote /data/packages/Zygote/TaBlo/src/lib/broadcast.jl:105
[17] #3920#back
@ /data/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
[18] Pullback
@ /data/packages/Flux/Zz9RI/src/layers/basic.jl:148 [inlined]
[19] (::typeof(∂(λ)))(Δ::Matrix{Float32})
@ Zygote /data/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[20] Pullback
@ /data/packages/Flux/Zz9RI/src/layers/basic.jl:37 [inlined]
[21] (::typeof(∂(applychain)))(Δ::Matrix{Float32})
@ Zygote /data/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[22] Pullback
@ /data/packages/Flux/Zz9RI/src/layers/basic.jl:37 [inlined]
[23] (::typeof(∂(applychain)))(Δ::Matrix{Float32})
@ Zygote /data/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[24] Pullback
@ /data/packages/Flux/Zz9RI/src/layers/basic.jl:39 [inlined]
[25] (::typeof(∂(λ)))(Δ::Matrix{Float32})
@ Zygote /data/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
[26] (::Zygote.var"#46#47"{typeof(∂(λ))})(Δ::Matrix{Float32})
@ Zygote /data/packages/Zygote/TaBlo/src/compiler/interface.jl:41
[27] ffjord(u::CuArray{Float32, 2}, p::CuArray{Float32, 1}, t::Float32, re::Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, e::Matrix{Float32}; regularize::Bool, monte_carlo::Bool)
@ DiffEqFlux /data/packages/DiffEqFlux/N7blG/src/ffjord.jl:179
[28] #59
@ /data/packages/DiffEqFlux/N7blG/src/ffjord.jl:194 [inlined]
[29] ODEFunction
@ /data/packages/SciMLBase/UIp7W/src/scimlfunctions.jl:334 [inlined]
[30] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5, false, CuArray{Float32, 2}, Nothing, Float32, CuArray{Float32, 1}, Float32, Float32, Float32, Float32, Vector{CuArray{Float32, 2}}, ODESolution{Float32, 3, Vector{CuArray{Float32, 2}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{CuArray{Float32, 2}}}, ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, DiffEqFlux.var"#59#64"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 1}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, FullNormal, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5, OrdinaryDiffEq.InterpolationData{ODEFunction{false, DiffEqFlux.var"#59#64"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 1}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, FullNormal, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{CuArray{Float32, 2}}, Vector{Float32}, Vector{Vector{CuArray{Float32, 2}}}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}}, DiffEqBase.DEStats}, ODEFunction{false, DiffEqFlux.var"#59#64"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 1}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, FullNormal, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, OrdinaryDiffEq.DEOptions{Float32, Float32, Float32, Float32, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Tuple{}, Tuple{}, Tuple{}}, CuArray{Float32, 2}, Float32, Nothing, OrdinaryDiffEq.DefaultInit}, cache::OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32})
@ OrdinaryDiffEq /data/packages/OrdinaryDiffEq/PZbGY/src/perform_step/low_order_rk_perform_step.jl:565
[31] __init(prob::ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, DiffEqFlux.var"#59#64"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 1}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, FullNormal, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, alg::Tsit5, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{Int64}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ OrdinaryDiffEq /data/packages/OrdinaryDiffEq/PZbGY/src/solve.jl:456
[32] __init(prob::ODEProblem{CuArray{Float32, 2}, Tuple{Float32, Float32}, false, CuArray{Float32, 1}, ODEFunction{false, DiffEqFlux.var"#59#64"{Bool, Bool, FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 1}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, FullNormal, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Matrix{Float32}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, alg::Tsit5, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}) (repeats 5 times)
@ OrdinaryDiffEq /data/packages/OrdinaryDiffEq/PZbGY/src/solve.jl:67
[33] #__solve#471
@ /data/packages/OrdinaryDiffEq/PZbGY/src/solve.jl:4 [inlined]
[34] __solve
@ /data/packages/OrdinaryDiffEq/PZbGY/src/solve.jl:4 [inlined]
[35] #solve_call#42
@ /data/packages/DiffEqBase/Rmj4o/src/solve.jl:61 [inlined]
[36] solve_call
@ /data/packages/DiffEqBase/Rmj4o/src/solve.jl:48 [inlined]
[37] #solve_up#44
@ /data/packages/DiffEqBase/Rmj4o/src/solve.jl:87 [inlined]
[38] solve_up
@ /data/packages/DiffEqBase/Rmj4o/src/solve.jl:78 [inlined]
[39] #solve#43
@ /data/packages/DiffEqBase/Rmj4o/src/solve.jl:73 [inlined]
[40] (::FFJORD{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}, CuArray{Float32, 1}, Flux.var"#60#62"{Chain{Tuple{Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}, Dense{typeof(tanh), CuArray{Float32, 2}, CuArray{Float32, 1}}}}}, FullNormal, Tuple{Float32, Float32}, Tuple{Tsit5}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}})(x::CuArray{Float32, 2}, p::CuArray{Float32, 1}, e::Matrix{Float32}; regularize::Bool, monte_carlo::Bool)
@ DiffEqFlux /data/packages/DiffEqFlux/N7blG/src/ffjord.jl:208
[41] FFJORD (repeats 2 times)
@ /data/packages/DiffEqFlux/N7blG/src/ffjord.jl:192 [inlined]
[42] loss(θ::CuArray{Float32, 1})
@ Main ./REPL[22]:2
[43] top-level scope
@ REPL[23]:1
[44] top-level scope
@ /data/packages/CUDA/DL5Zo/src/initialization.jl:52
That trace suggests that something isn't on the GPU; notice the Matrix{Float32} in there.
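If the stray CPU array is the Monte Carlo noise e (frames [27] and [40] show e::Matrix{Float32}), a hypothetical workaround sketch is to supply the noise on the GPU explicitly; note that the positional e argument is read off the trace above, not off a documented API:
using CUDA
e = CUDA.randn(Float32, size(train_data)...)   # noise with the data's shape, on the GPU
logpx, λ₁, λ₂ = ffjord_mdl(train_data, ffjord_mdl.p, e)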
Yeah, any issue with FFJORD is probably separate. Let's solve this one in this issue ASAP and keep FFJORD for another day.
Regarding the VectorOfArray issue: cu(x) calls adapt(CuArray{Float32}, x), which descends into types (thanks to Adapt.jl) and ultimately calls CuArray{Float32}(...) on the leaves: https://github.com/JuliaGPU/CUDA.jl/blob/78379e1786dba80e396ca362a7546fc6d7b488e1/src/array.jl#L429-L430. And VectorOfArray doesn't support being converted to an array with an element type:
julia> Array(VectorOfArray([rand(Float32,4),rand(Float32,4)]))
4×2 Matrix{Float32}:
0.412813 0.752041
0.642892 0.894659
0.477681 0.466676
0.082933 0.0945987
julia> Array{Float32}(VectorOfArray([rand(Float32,4),rand(Float32,4)]))
ERROR: MethodError: Cannot `convert` an object of type Vector{Float32} to an object of type Float32
Closest candidates are:
convert(::Type{T}, ::ColorTypes.Gray24) where T<:Real at /home/tim/Julia/depot/packages/ColorTypes/6m8P7/src/conversions.jl:114
convert(::Type{T}, ::ColorTypes.Gray) where T<:Real at /home/tim/Julia/depot/packages/ColorTypes/6m8P7/src/conversions.jl:113
convert(::Type{T}, ::LLVM.GenericValue, ::LLVM.LLVMType) where T<:AbstractFloat at /home/tim/Julia/depot/packages/LLVM/FrlPu/src/execution.jl:39
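To summarize the chain as I understand it (illustrative; the exact dispatch lives in the linked CUDA.jl source):
using Adapt, CUDA, RecursiveArrayTools
x = VectorOfArray([rand(Float32, 4), rand(Float32, 4)])
cu(x)                  # ≈ adapt(CuArray{Float32}, x)
CuArray{Float32}(x)    # the leaf constructor adapt ends up calling...
Array{Float32}(x)      # ...which gathers via this method, the one VectorOfArray lacks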
How did it work until the update, though? Was cu calling CuArray instead of adapt before?
https://github.com/SciML/RecursiveArrayTools.jl/blob/master/src/init.jl#L23-L30
cu has been using Adapt for ages. Not sure what changed, as CuArray(....) was calling Array on 3.3.6 too: https://github.com/JuliaGPU/CUDA.jl/blob/964893c8cd2c7c8f73de0df1c48d3237d2e07414/src/array.jl#L240-L244
Anyway, there's clearly a method missing for VectorOfArray; can we not add it there? This should do it:
Base.Array{U}(VA::AbstractVectorOfArray{T,N,A}) where {T,U,N,A <: AbstractVector{<:AbstractVector}} = reduce(hcat, map(x -> U.(x), VA.u))
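A quick sanity check of the proposed overload (illustrative; the cu path working again relies on the local confirmation below):
using CUDA, RecursiveArrayTools
x = VectorOfArray([rand(Float32, 4), rand(Float32, 4)])
Array{Float32}(x)   # now returns a 4×2 Matrix{Float32}
cu(x)               # and the adapt path can reach the leaves again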
Yeah, I was just having trouble figuring out which overload was missing, so I was trying to track down what the change was. That seems to fix it all locally, so the new patch should handle this fine. Thanks, @ccasert!
Commenting here from the Flux issue, could it be that https://github.com/JuliaGPU/GPUArrays.jl/pull/368 (which was in GPUArrays 8/CUDA 3.4 but not GPUArrays 7/CUDA 3.3) changed things?
I'm running into issues when trying out the following GPU example: https://diffeqflux.sciml.ai/dev/GPUs/. At the line
I get this error
These are my package versions:
and this is my CUDA version:
The code runs without any errors on the CPU. What could be the problem here? Thanks!