DrChainsaw closed this issue 3 years ago.
Don't touch the adapt method. Normally you override adapt_structure to recurse into wrappers, and for array types to work with that you can extend adapt_storage. So here you want Flux.Zeros not to materialize anywhere, which means you need to implement:

Adapt.adapt_structure(to, x::Flux.Zeros) = x
julia> using CUDA
julia> cu(Flux.Zeros())
0-dimensional Flux.Zeros{Bool,0}:
0
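For context, here is a minimal sketch of the usual split; the Scaled wrapper is made up for illustration and is not Flux code. Wrapper types extend adapt_structure to rebuild themselves around adapted fields, while adapt_storage is left to handle conversion of the leaf storage.

using Adapt

# Hypothetical wrapper type, only for illustration.
struct Scaled{T,A}
    scale::T
    data::A
end

# Rebuild the wrapper around whatever the adapted storage becomes;
# adapt(to, x.data) will in turn hit adapt_storage for the underlying array.
Adapt.adapt_structure(to, x::Scaled) = Scaled(x.scale, Adapt.adapt(to, x.data))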
You can't extend adapt_storage because with both arguments non-concrete you'd run into ambiguities. Also note you might want a specialized version for CUDA.jl's Float32Adaptor to change the Zeros eltype to Float32, if that's required.
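If that conversion is wanted, a hedged sketch could look like the following. CUDA.Float32Adaptor is an internal type used by cu (it may be named differently in other CUDA.jl versions), and the Flux.Zeros(Float32, size(x)...) constructor is an assumption about Flux's API here.

# Sketch only: CUDA.Float32Adaptor is internal to CUDA.jl, and the
# Flux.Zeros(Float32, size(x)...) constructor is assumed to exist.
Adapt.adapt_structure(::CUDA.Float32Adaptor, x::Flux.Zeros) = Flux.Zeros(Float32, size(x)...)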
Also note that this would have worked out of the box if Flux.Zeros were immutable. I'm assuming you won't actually use this type on the GPU, because then it would need to be isbits (mutable types are not supported on the GPU).
julia> cu([1]) .+ 0 .+ Flux.Zeros()
ERROR: GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceArray{Int64,1,1}, Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},typeof(+),Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(+),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Int64,1,1},Tuple{Bool},Tuple{Int64}},Int64}},Base.Broadcast.Extruded{Flux.Zeros{Bool,0},Tuple{},Tuple{}}}}, Int64) failed
KernelError: passing and using non-bitstype argument
Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},typeof(+),Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(+),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Int64,1,1},Tuple{Bool},Tuple{Int64}},Int64}},Base.Broadcast.Extruded{Flux.Zeros{Bool,0},Tuple{},Tuple{}}}}, which is not isbits:
.args is of type Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(+),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Int64,1,1},Tuple{Bool},Tuple{Int64}},Int64}},Base.Broadcast.Extruded{Flux.Zeros{Bool,0},Tuple{},Tuple{}}} which is not isbits.
.2 is of type Base.Broadcast.Extruded{Flux.Zeros{Bool,0},Tuple{},Tuple{}} which is not isbits.
.x is of type Flux.Zeros{Bool,0} which is not isbits.
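To illustrate the isbits point, here is a hypothetical stand-in, not Flux's actual definition: an immutable struct whose only field is itself isbits can be passed into GPU kernels, whereas a mutable struct with the same field cannot.

# Hypothetical immutable stand-in for illustration; not Flux's actual type.
struct ImmutableZeros{T,N} <: AbstractArray{T,N}
    size::NTuple{N,Int}
end

isbitstype(ImmutableZeros{Bool,0})  # true: safe to pass to a GPU kernel

# Declaring the same thing as `mutable struct` would make this false,
# which is exactly what the KernelError above complains about.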
Would
adapt(T, x::MyType) = x
be considered valid use of this package? The first sentence in the readme makes it seem like it is. The reason I'm asking is that it could be a solution to this issue in Flux.
No biggie if it isn't, I guess, as one could just create a new abstraction, but it would be a bit redundant if the above is exactly what this package is meant to do.