gerhardJaeger closed this issue 3 years ago.
This is the same as #66. The quickest fix is to replace `label = onehot(vocab, y)` with `label = Basic.OneHotArray(length(vocab), y)`. This is due to an API change in `vocab` and `onehot`.
Note: this fix is only needed for, and will only work with, this example. That's because this tutorial uses `Int` as the vocabulary. If you are working with real data (i.e. a `String` vocabulary), you won't need this change.
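A plain-Julia sketch of why an `Int` vocabulary is special here, assuming (as the fix suggests) that `y` already holds vocabulary indices by the time it reaches `loss`. With `Int` tokens, a token and an index are both just integers, so nothing distinguishes encoded data from raw data; with `String` tokens the two cannot be confused. The names `vocab_tokens`, `token_index` and `onehot_from_indices` are hypothetical and not part of Transformers.jl, and the 13-entry vocabulary is only an assumption chosen to match the `OneHot{0x0000000d}` (13) in the error. `onehot_from_indices` builds a dense `BitMatrix` for readability; the library's one-hot array type uses a different representation.

```julia
# Hypothetical 13-entry Int vocabulary (an assumption, not taken from the tutorial code).
vocab_tokens = [0, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

# Hypothetical token -> index lookup, i.e. "applying the vocabulary" to raw tokens.
token_index(tok) = findfirst(==(tok), vocab_tokens)

# Hypothetical helper: dense one-hot matrix built directly from integer indices.
# Column j has a single `true` at row y[j]; no vocabulary lookup happens here.
function onehot_from_indices(n::Integer, y::AbstractVector{<:Integer})
    M = falses(n, length(y))
    for (j, i) in enumerate(y)
        1 <= i <= n || throw(ArgumentError("index $i outside 1:$n"))
        M[i, j] = true
    end
    return M
end

raw = [1, 2, 3]              # raw Int tokens
idx = token_index.(raw)      # encoded once: vocabulary indices

# Treat the already-encoded data as indices (what the suggested fix does conceptually).
A = onehot_from_indices(length(vocab_tokens), idx)

# Look the indices up again as if they were tokens: still integers, so no error,
# but the wrong rows get marked.
B = onehot_from_indices(length(vocab_tokens), token_index.(idx))
```

Here `A` marks rows 4, 5, 6 (the indices of tokens 1, 2, 3), while `B`, which looks the indices up a second time, marks rows 7, 8, 9.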
This works, thanks a lot!
I tried to run the tutorial, but something seems to be wrong with the `onehot` call inside the definition of `loss`. The error message is below.
Help is appreciated. Thanks, Gerhard
[ Info: start training
ERROR: LoadError: MethodError: no method matching OneHotArray(::Transformers.Basic.OneHot{0x0000000d})
Closest candidates are:
  OneHotArray(::Any, ::Transformers.Basic.OneHot) at /home/gjaeger/.julia/packages/Transformers/3YgSd/src/basic/embeds/onehot.jl:115
  OneHotArray(::Any, ::Any) at /home/gjaeger/.julia/packages/Transformers/3YgSd/src/basic/embeds/onehot.jl:114
  OneHotArray(::A) where {K, A<:(AbstractArray{Transformers.Basic.OneHot{K}, N} where N)} at /home/gjaeger/.julia/packages/Transformers/3YgSd/src/basic/embeds/onehot.jl:112
Stacktrace:
  [1] OneHotArray(k::Int64, xs::Int64)
    @ Transformers.Basic ~/.julia/packages/Transformers/3YgSd/src/basic/embeds/onehot.jl:114
  [2] onehot(v::Vocabulary{Int64}, x::CuArray{Int64, 2, CUDA.Mem.DeviceBuffer})
    @ Transformers.Basic ~/.julia/packages/Transformers/3YgSd/src/basic/embeds/vocab.jl:95
  [3] _pullback
    @ ~/.julia/packages/Zygote/l3aNG/src/lib/grad.jl:8 [inlined]
  [4] _pullback
    @ ~/projects/research/tryTransformer/code/tutorial.jl:99 [inlined]
  [5] _pullback(::Zygote.Context, ::typeof(loss), ::CuArray{Int64, 2, CUDA.Mem.DeviceBuffer}, ::CuArray{Int64, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
  [6] _pullback
    @ ~/projects/research/tryTransformer/code/tutorial.jl:133 [inlined]
  [7] _pullback(::Zygote.Context, ::var"#8#10")
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface2.jl:0
  [8] pullback(f::Function, ps::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:343
  [9] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/l3aNG/src/compiler/interface.jl:75
 [10] train!()
    @ Main ~/projects/research/tryTransformer/code/tutorial.jl:133