chengchingwen / Transformers.jl

Julia Implementation of Transformer models
MIT License

Incorrect method signature with Embed and CUDA #50

Closed: janheinrichmerker closed this issue 3 years ago

janheinrichmerker commented 3 years ago

I'm using this fantastic library for a small study at university. When I use a trainable embedding on the GPU, I get the following error:

```
ERROR: LoadError: MethodError: no method matching scatter_add!(::CuArray{Float32,2}, ::CuArray{Float32,2}, ::Array{Int64,1})
Closest candidates are:
  scatter_add!(::CuArray{T,2}, ::CuArray{T,N} where N, ::CuArray{Int64,N} where N) where T at /home/me/.julia/packages/Transformers/ko7g9/src/cuda/scatter_gpu.jl:3
  scatter_add!(::CuArray{T,N} where N, ::CuArray{T,N} where N, ::CuArray{var"#s62",N} where N where var"#s62"<:Tuple) where T at /home/me/.julia/packages/Transformers/ko7g9/src/cuda/scatter_gpu.jl:32
  scatter_add!(::Array{T,2}, ::Array{T,N} where N, ::Array{Int64,N} where N) where T at /home/me/.julia/packages/Transformers/ko7g9/src/fix/scatter.jl:2
Stacktrace:
 [1] ∇gather(::CuArray{Float32,2}, ::CuArray{Float32,2}, ::Array{Int64,1}) at /home/me/.julia/packages/Transformers/ko7g9/src/basic/embeds/gather.jl:41
 [2] (::Transformers.Basic.var"#33#34"{CuArray{Float32,2},Array{Int64,1}})(::CuArray{Float32,2}) at /home/me/.julia/packages/Transformers/ko7g9/src/basic/embeds/gather.jl:55
 [3] (::Transformers.Basic.var"#294#back#35"{Transformers.Basic.var"#33#34"{CuArray{Float32,2},Array{Int64,1}}})(::CuArray{Float32,2}) at /home/me/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [4] Embed at /home/me/.julia/packages/Transformers/ko7g9/src/basic/embeds/embed.jl:25 [inlined]
 [5] (::typeof(∂(λ)))(::CuArray{Float32,2}) at /home/me/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
 [6] Embed at /home/me/.julia/packages/Transformers/ko7g9/src/basic/embeds/embed.jl:21 [inlined]
 [7] (::typeof(∂(λ)))(::CuArray{Float32,2}) at /home/me/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
 [8] WordPositionEmbed at /home/me/project/src/model/embed.jl:13 [inlined]
 [9] (::typeof(∂(λ)))(::CuArray{Float32,2}) at /home/me/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
 [10] TransformersModel at /home/me/project/src/model/transformers.jl:44 [inlined]
 [11] (::typeof(∂(λ)))(::CuArray{Float32,2}) at /home/me/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
 [12] loss at /home/me/project/src/training.jl:60 [inlined]
 [13] (::typeof(∂(loss)))(::Float32) at /home/me/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
 [14] #103 at /home/me/project/src/training.jl:99 [inlined]
 [15] (::typeof(∂(λ)))(::Float32) at /home/me/.julia/packages/Zygote/ggM8Z/src/compiler/interface2.jl:0
 [16] (::Zygote.var"#54#55"{Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /home/me/.julia/packages/Zygote/ggM8Z/src/compiler/interface.jl:172
 [17] gradient(::Function, ::Params) at /home/me/.julia/packages/Zygote/ggM8Z/src/compiler/interface.jl:49
 [18] top-level scope at timing.jl:310
 [19] top-level scope at /home/me/project/src/training.jl:98
 [20] include(::Function, ::Module, ::String) at ./Base.jl:380
 [21] include(::Module, ::String) at ./Base.jl:368
 [22] exec_options(::Base.JLOptions) at ./client.jl:296
 [23] _start() at ./client.jl:506
in expression starting at /home/me/project/src/training.jl:84
```

I'm using Transformers 0.1.7 with Flux 0.11.1 on Julia 1.5.3. It seems that the ∇gather function does not work with CuArrays (the source also says that this function is for the CPU):

https://github.com/chengchingwen/Transformers.jl/blob/e7e7b74e2c1fd20656b603b644b8a8e1b99ef3ea/src/basic/embeds/gather.jl#L38-L43
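For readers unfamiliar with the gradient in the traceback, here is a minimal, CPU-only sketch in plain Julia of what an embedding gather and its gradient compute. The names gather_ref, scatter_add_ref! and ∇gather_ref are hypothetical and only mirror the traceback; this is not the Transformers.jl source. The point it illustrates is that the backward pass scatter-adds the incoming gradient into a zero matrix shaped like the embedding, using the same token indices as the forward pass, so the index array takes part in the backward computation too.

```julia
# Hypothetical reference implementations (NOT Transformers.jl's code).

# Forward: pick the embedding column for every token index.
gather_ref(w::AbstractMatrix, xs::AbstractVector{<:Integer}) = w[:, xs]

# Backward helper: add column j of `src` into column xs[j] of `dst`.
# Repeated indices must accumulate, hence scatter-*add* rather than assignment.
function scatter_add_ref!(dst::AbstractMatrix, src::AbstractMatrix,
                          xs::AbstractVector{<:Integer})
    for (j, i) in enumerate(xs)
        @views dst[:, i] .+= src[:, j]
    end
    return dst
end

# Gradient w.r.t. the embedding matrix: zero everywhere except gathered columns.
∇gather_ref(Δ::AbstractMatrix, w::AbstractMatrix, xs::AbstractVector{<:Integer}) =
    scatter_add_ref!(zero(w), Δ, xs)

w = Float32[1 2 3; 4 5 6]   # hidden size 2, vocabulary size 3
xs = [3, 1, 3]              # token 3 appears twice
y = gather_ref(w, xs)       # columns 3, 1, 3 of w
Δ = ones(Float32, size(y))  # pretend upstream gradient
g = ∇gather_ref(Δ, w, xs)   # column 3 accumulates two contributions
```

Note that the helper's three arguments (destination, upstream gradient, indices) line up with the three arguments of the failing scatter_add! call in the error message.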

Any help is highly appreciated! :+1:

chengchingwen commented 3 years ago

It looks like your index array is not on the GPU (the third argument in the error is a plain Array{Int64,1}, not a CuArray). All arrays need to be moved to the GPU with CUDA.cu before doing any computation.
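Why a CPU index array ends in a MethodError instead of silently working: Julia selects a method from the types of all arguments, and the GPU candidate listed in the error requires the index argument to be a CuArray as well. The toy below reproduces that situation without a GPU. OnDevice and to_device are hypothetical stand-ins for CuArray and CUDA.cu, not real CUDA.jl API, and toy_scatter_add! is a simplified illustration rather than the library's kernel.

```julia
# HYPOTHETICAL stand-in for a GPU array type; this is not CUDA.jl's CuArray.
struct OnDevice{A<:AbstractArray}
    data::A
end

# HYPOTHETICAL analogue of CUDA.cu: wrap a host array so it counts as "on the device".
to_device(x::AbstractArray) = OnDevice(x)

# Like the GPU candidate in the error message, this method only exists when the
# destination, the gradient AND the indices are all device arrays.
function toy_scatter_add!(dst::OnDevice{<:AbstractMatrix},
                          src::OnDevice{<:AbstractMatrix},
                          idx::OnDevice{<:AbstractVector{<:Integer}})
    for (j, i) in enumerate(idx.data)
        @views dst.data[:, i] .+= src.data[:, j]
    end
    return dst
end

grad = to_device(zeros(Float32, 2, 3))  # "device" gradient buffer
Δ = to_device(ones(Float32, 2, 2))      # "device" upstream gradient
host_idx = [1, 3]                       # still a plain Vector{Int64}

# toy_scatter_add!(grad, Δ, host_idx)   # would throw a MethodError: no method matches
toy_scatter_add!(grad, Δ, to_device(host_idx))  # works once the indices are moved too
```

With CUDA.jl, the corresponding fix is the one suggested above: call CUDA.cu on the index array as well as on the model, not only on the parameters.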

janheinrichmerker commented 3 years ago

Thanks a lot! I had accidentally used the onehot function from Flux.jl. It was fixed by switching to the onehot implementation from Transformers.jl. I'm still quite new to machine learning in Julia, and the whole GPU stuff is a bit confusing to me :smile:
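The thread does not say exactly how the Flux.jl one-hot output ended up as a CPU index array. As a hedged illustration only: an embedding lookup on a one-hot input is equivalent to indexing the embedding matrix with the integer positions, which is why an efficient one-hot type can store just those positions, and why that storage has to live on the same device as the embedding. The Base-Julia sketch below checks the equivalence on the CPU; onehot_matrix is a hypothetical helper, not Flux's onehotbatch or Transformers.jl's onehot.

```julia
# HYPOTHETICAL helper: dense one-hot matrix (vocab_size × batch) from integer indices.
function onehot_matrix(xs::AbstractVector{<:Integer}, vocab_size::Integer)
    m = zeros(Float32, vocab_size, length(xs))
    for (j, i) in enumerate(xs)
        m[i, j] = 1   # column j has a single 1 in row xs[j]
    end
    return m
end

w = Float32[1 2 3; 4 5 6]   # hidden × vocab embedding matrix
xs = [3, 1, 3]              # token indices

emb_onehot = w * onehot_matrix(xs, size(w, 2))  # lookup via matrix multiplication
emb_index = w[:, xs]                             # lookup via indexing (a gather)
```

Both lookups produce the same hidden × batch matrix; the indexed version just skips the multiplication by zeros.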