JuliaGPU / Adapt.jl


Question about valid use of this package #31

Closed DrChainsaw closed 3 years ago

DrChainsaw commented 3 years ago

Would adapt(T, x::MyType) = x be considered valid use of this package? The first sentence in the readme makes it seem like it is.

Reason for asking is that it could be a solution to this issue in Flux.

No biggie if it isn't, I guess, as one could just create a new abstraction, but it would be a bit redundant if the above is exactly what this package is meant to do.

maleadt commented 3 years ago

Don't touch the adapt method itself. Normally you override adapt_structure to recurse into wrappers, and extend adapt_storage to make array types work with that. So here you want Flux.Zeros not to materialize anywhere, which means you need to implement:

Adapt.adapt_structure(to, x::Flux.Zeros) = x

julia> using CUDA

julia> cu(Flux.Zeros())
0-dimensional Flux.Zeros{Bool,0}:
0
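To illustrate the adapt_structure pattern mentioned above: for a wrapper type, the method rebuilds the wrapper around adapted fields, so only the storage inside gets converted. A minimal sketch, assuming Adapt.jl is installed; Scaled is a hypothetical wrapper, not part of Flux or Adapt:

```julia
using Adapt

# Hypothetical wrapper: some storage plus non-storage metadata.
struct Scaled{T}
    data::T
    factor::Float64
end

# Rebuild the wrapper, adapting only the field that holds storage.
Adapt.adapt_structure(to, s::Scaled) = Scaled(adapt(to, s.data), s.factor)

s = Scaled([1.0, 2.0], 0.5)
adapt(Array{Float32}, s)  # data converted to Float32, wrapper and factor preserved
```

The Flux.Zeros method above is the degenerate case of this pattern: there is no storage to adapt, so the method just returns the value unchanged.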

You can't extend adapt_storage because with both arguments non-concrete you'd run into ambiguities.
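A self-contained sketch of that ambiguity, using hypothetical stand-in functions and types rather than Adapt itself: when one method specializes only the first argument and another only the second, a call matching both has no unique most-specific method.

```julia
# One method specific in the adaptor position, one in the value position:
g(to::AbstractDict, x) = "adaptor-specific"
g(to, x::AbstractArray) = "value-specific"

# A call matching both specializations is ambiguous:
g(Dict(), [1, 2, 3])  # ERROR: MethodError: g(...) is ambiguous
```

With adapt_structure the first argument is left fully generic by convention, so a method specializing only the value type slots in without conflict.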

Also note you might want a specialized version for CUDA.jl's Float32Adaptor to change the Zeros eltype to a Float32, if that's required.

Also note that this would have worked out of the box if Flux.Zeros were immutable. I'm assuming you won't actually use this type on the GPU, because then it would need to be isbits (mutable types are not supported on the GPU).

maleadt commented 3 years ago

> Also note that this would have worked out of the box if Flux.Zeros were immutable. I'm assuming you won't actually use this type on the GPU, because then it would need to be isbits (mutable types are not supported on the GPU).

julia> cu([1]) .+ 0 .+ Flux.Zeros()
ERROR: GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceArray{Int64,1,1}, Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},typeof(+),Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(+),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Int64,1,1},Tuple{Bool},Tuple{Int64}},Int64}},Base.Broadcast.Extruded{Flux.Zeros{Bool,0},Tuple{},Tuple{}}}}, Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64}},typeof(+),Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(+),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Int64,1,1},Tuple{Bool},Tuple{Int64}},Int64}},Base.Broadcast.Extruded{Flux.Zeros{Bool,0},Tuple{},Tuple{}}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1},Nothing,typeof(+),Tuple{Base.Broadcast.Extruded{CuDeviceArray{Int64,1,1},Tuple{Bool},Tuple{Int64}},Int64}},Base.Broadcast.Extruded{Flux.Zeros{Bool,0},Tuple{},Tuple{}}} which is not isbits.
    .2 is of type Base.Broadcast.Extruded{Flux.Zeros{Bool,0},Tuple{},Tuple{}} which is not isbits.
      .x is of type Flux.Zeros{Bool,0} which is not isbits.
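The failure above comes down to the isbits requirement: mutable structs are never isbits, while an immutable struct with isbits fields is. A small illustration with hypothetical stand-in types:

```julia
mutable struct MutableZeros end   # mutable, like Flux.Zeros at the time
struct ImmutableZeros end         # the immutable equivalent

isbitstype(MutableZeros)    # false: heap-allocated, can't be a GPU kernel argument
isbitstype(ImmutableZeros)  # true: plain bits, fine to pass to a kernel
```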