chengchingwen / Transformers.jl

Julia Implementation of Transformer models
MIT License

Basic.onehot #75

Closed gerhardJaeger closed 3 years ago

gerhardJaeger commented 3 years ago

I am trying to run the tutorial, but something seems to be wrong with the onehot call inside the definition of loss. The error message is below.

Help is appreciated. Thanks, Gerhard

[ Info: start training
ERROR: LoadError: MethodError: no method matching OneHotArray(::Transformers.Basic.OneHot{0x0000000d})
Closest candidates are:
  OneHotArray(::Any, ::Transformers.Basic.OneHot) at /home/gjaeger/.julia/packages/Transformers/3YgSd/src/basic/embeds/onehot.jl:115
  OneHotArray(::Any, ::Any) at /home/gjaeger/.julia/packages/Transformers/3YgSd/src/basic/embeds/onehot.jl:114
  OneHotArray(::A) where {K, A<:(AbstractArray{Transformers.Basic.OneHot{K}, N} where N)} at /home/gjaeger/.julia/packages/Transformers/3YgSd/src/basic/embeds/onehot.jl:112
Stacktrace:
  [1] OneHotArray(k::Int64, xs::Int64)
    @ Transformers.Basic ~/.julia/packages/Transformers/3YgSd/src/basic/embeds/onehot.jl:114
  [2] onehot(v::Vocabulary{Int64}, x::CuArray{Int64, 2, CUDA.Mem.DeviceBuffer})
    @ Transformers.Basic ~/.julia/packages/Transformers/3YgSd/src/basic/embeds/vocab.jl:95
  [3] _pullback
    @ ~/.julia/packages/Zygote/l3aNG/src/lib/grad.jl:8 [inlined]
  [4] _pullback
    @ ~/projects/research/tryTransformer/code/tutorial.jl:99 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(loss), ::CuArray{Int64, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Int64, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
  [6] _pullback
    @ ~/projects/research/tryTransformer/code/tutorial.jl:133 [inlined]
  [7] _pullback(::Zygote.Context, ::var"#8#10")
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
  [8] pullback(f::Function, ps::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:343
  [9] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:75
 [10] train!()
    @ Main ~/projects/research/tryTransformer/code/tutorial.jl:133

chengchingwen commented 3 years ago

This is the same as #66. The quickest fix is to replace label = onehot(vocab, y) with label = Basic.OneHotArray(length(vocab), y). The error is caused by an API change to vocab and onehot.

Note: this fix is only needed for, and will only work with, this example, because the tutorial uses Int as the vocabulary type. If you are working with real data (e.g. a String vocabulary), you won't need this change.
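As a rough illustration of what the workaround amounts to, the sketch below one-hot encodes integer labels directly, treating each label as an index in 1:k instead of looking it up in a vocabulary. It assumes the labels in y are already valid indices, which is what passing length(vocab) and y straight to OneHotArray suggests. The helper onehot_from_indices is hypothetical plain Julia, not part of Transformers.jl, and it returns a dense BitArray rather than the library's OneHotArray type.

```julia
# Hypothetical helper, NOT part of Transformers.jl: builds a dense one-hot
# array from integer labels that are assumed to be valid indices in 1:k.
function onehot_from_indices(k::Int, labels::AbstractArray{<:Integer})
    @assert all(1 .<= labels .<= k) "labels must lie in 1:k"
    # The one-hot dimension comes first, followed by the dimensions of `labels`.
    out = falses(k, size(labels)...)
    for I in CartesianIndices(labels)
        out[labels[I], I] = true
    end
    return out
end

y = [1 3; 2 1]                      # toy integer labels with shape (2, 2)
label = onehot_from_indices(3, y)   # plays the role of OneHotArray(length(vocab), y)
size(label)                         # (3, 2, 2)
```

With a String vocabulary the tokens and their integer indices have different types, so no such direct construction is needed and onehot(vocab, y) can be left as it is.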

gerhardJaeger commented 3 years ago

This works, thanks a lot!