denizyuret / Knet.jl

Koç University deep learning framework.
https://denizyuret.github.io/Knet.jl/latest

"vcat" doesnt work for 3D KnetArrays #470

Closed denglerchr closed 5 years ago

denglerchr commented 5 years ago

I found that "vcat" and "cat(..., dims = 1)" do not work for KnetArrays with more than 2 dimensions. I think supporting this would be a nice feature, as RNNs deal with 3-dimensional data. P.S. Using Knet 1.2.2 on Julia 1.0.3.
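
For reference, a minimal reproduction (a sketch; the exact error text may vary by version):

using Knet
A = KnetArray(rand(Float32, 3, 4, 5))
B = KnetArray(rand(Float32, 3, 4, 5))
vcat(A, B)           # errors on Knet 1.2.2
cat(A, B, dims = 1)  # errors as well for 3-D KnetArrays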

ilkerkesen commented 5 years ago

We can accomplish this with a reshape trick, which makes it work (at least) for the first and last dimensions. The trick relies on Julia's column-major memory layout: flattening every dimension except the one being concatenated reduces the problem to a 2-D vcat (first dimension) or hcat (last dimension),

A = rand(3,4,5);
B = rand(3,4,5);
cat1 = cat(A,B,dims=1);
cat3 = cat(A,B,dims=3);
cat1a = vcat(reshape(A, size(A,1), :), reshape(B, size(B,1), :));
cat1b = reshape(cat1a, size(A,1)+size(B,1), size(A)[2:3]...);
cat1b == cat1 # true
cat3a = hcat(reshape(A, :, size(A,3)), reshape(B, :, size(B,3)));
cat3b = reshape(cat3a, size(A)[1:2]..., size(A,3)+size(B,3));
cat3b == cat3 # true
denglerchr commented 5 years ago

Yes, something like that works. I used this for the forward pass, but then I ran into problems with the backward pass.

function vcat(X::KnetArray{T,3}, Y::KnetArray{T,3}) where {T}
    @assert size(X, 2) == size(Y, 2)
    @assert size(X, 3) == size(Y, 3)

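    # Flatten the trailing dimensions, vcat in 2-D, then restore the 3-D shape.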
    X2 = reshape(X, size(X, 1), size(X, 2)*size(X, 3))
    Y2 = reshape(Y, size(Y, 1), size(Y, 2)*size(Y, 3))
    temp = vcat(X2, Y2)
    return reshape(temp, size(X, 1)+size(Y, 1), size(X, 2), size(X, 3))
end
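
One way to make the backward pass work (a sketch, not the actual Knet fix): drop the KnetArray-only argument types so that AutoGrad.Result values recorded on the tape can flow through reshape and the 2-D vcat, both of which already have gradients. The name vcat3 below is hypothetical, chosen to avoid clashing with Base.vcat:

# Hypothetical helper: no argument type restriction, so AutoGrad.Result
# values pass through; reshape and 2-D vcat are already differentiable.
function vcat3(X, Y)
    @assert size(X, 2) == size(Y, 2)
    @assert size(X, 3) == size(Y, 3)
    X2 = reshape(X, size(X, 1), size(X, 2)*size(X, 3))  # flatten to 2-D
    Y2 = reshape(Y, size(Y, 1), size(Y, 2)*size(Y, 3))
    temp = vcat(X2, Y2)                                  # differentiable 2-D vcat
    return reshape(temp, size(X, 1)+size(Y, 1), size(X, 2), size(X, 3))
end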
ilkerkesen commented 5 years ago

Can you please paste the full error backtrace? This is related to the input argument types: the function above cannot handle the backward-pass value type (I don't remember its exact name; it should be something like AutoGrad.Result).

denglerchr commented 5 years ago

It is quite a long one. Some method is missing, but it says it is a method for KnetArrays. I don't really understand it; it seems to be looking for a cat method without any information on the dimension.


Stacktrace:
 [1] #cat#38(::Val{1}, ::Function, ::KnetArray{Float32,3}, ::KnetArray{Float32,3}) at C:\Users\dengl\.julia\packages\Knet\HwZrA\src\karray.jl:326
 [2] (::getfield(Base, Symbol("#kw##cat")))(::NamedTuple{(:dims,),Tuple{Val{1}}}, ::typeof(cat), ::KnetArray{Float32,3}, ::KnetArray{Float32,3}) at .\none:0
 [3] #forw#1(::Base.Iterators.Pairs{Symbol,Val{1},Tuple{Symbol},NamedTuple{(:dims,),Tuple{Val{1}}}}, ::Function, ::Function, ::KnetArray{Float32,3}, ::Vararg{Any,N} where N) at C:\Users\dengl\.julia\packages\AutoGrad\FKOf4\src\core.jl:66
 [4] #forw at .\none:0 [inlined]
 [5] #cat#28 at C:\Users\dengl\.julia\packages\Knet\HwZrA\src\karray.jl:229 [inlined]
 [6] #cat at .\none:0 [inlined]
 [7] vcat(::KnetArray{Float32,3}, ::AutoGrad.Result{KnetArray{Float32,3}}) at .\abstractarray.jl:1418
 [8] (::Chain2{Tuple{Knet.RNN{KnetArray{Float32,3}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}}}})(::KnetArray{Float32,3}) at C:\Users\dengl\Dropbox\Uni\Projekte\2019_RNN_Segway\01_Code\01_Supervised\src\Functions\rnn.jl:20
 [9] lossfunc(::KnetArray{Float32,3}, ::KnetArray{Float32,3}, ::Chain2{Tuple{Knet.RNN{KnetArray{Float32,3}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}}}}) at C:\Users\dengl\Dropbox\Uni\Projekte\2019_RNN_Segway\01_Code\01_Supervised\src\Functions\train.jl:5
 [10] (::getfield(Main, Symbol("##37#41")){Chain2{Tuple{Knet.RNN{KnetArray{Float32,3}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}}}}})(::KnetArray{Float32,3}, ::KnetArray{Float32,3}) at C:\Users\dengl\Dropbox\Uni\Projekte\2019_RNN_Segway\01_Code\01_Supervised\src\Functions\train.jl:17
 [11] (::getfield(Knet, Symbol("##664#665")){Knet.Minimize{DataSlicer{KnetArray{Float32,3},KnetArray{Float32,N} where N}},Tuple{KnetArray{Float32,3},KnetArray{Float32,3}}})() at C:\Users\dengl\.julia\packages\AutoGrad\FKOf4\src\core.jl:197
 [12] #differentiate#3(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::Function) at C:\Users\dengl\.julia\packages\AutoGrad\FKOf4\src\core.jl:144
 [13] differentiate(::Function) at C:\Users\dengl\.julia\packages\AutoGrad\FKOf4\src\core.jl:135
 [14] iterate(::Knet.Minimize{DataSlicer{KnetArray{Float32,3},KnetArray{Float32,N} where N}}) at C:\Users\dengl\.julia\packages\Knet\HwZrA\src\train.jl:24
 [15] #train_rnn!#34(::Int64, ::Int64, ::Int64, ::String, ::Function, ::Rnn_data, ::Chain2{Tuple{Knet.RNN{KnetArray{Float32,3}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}}}}) at C:\Users\dengl\Dropbox\Uni\Projekte\2019_RNN_Segway\01_Code\01_Supervised\src\Functions\train.jl:21
 [16] (::getfield(Main, Symbol("#kw##train_rnn!")))(::NamedTuple{(:max_epochs, :Nvalid, :Nsave, :filename),Tuple{Int64,Int64,Int64,String}}, ::typeof(train_rnn!), ::Rnn_data, ::Chain2{Tuple{Knet.RNN{KnetArray{Float32,3}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}}}}) at .\none:0
 [17] top-level scope at util.jl:156
 [18] include at .\boot.jl:317 [inlined]
 [19] include_relative(::Module, ::String) at .\loading.jl:1044
 [20] include(::Module, ::String) at .\sysimg.jl:29
 [21] include(::String) at .\client.jl:392
 [22] top-level scope at none:0
 [23] eval(::Module, ::Any) at .\boot.jl:319
 [24] eval_user_input(::Any, ::REPL.REPLBackend) at C:\cygwin\home\Administrator\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.0\REPL\src\REPL.jl:85
 [25] macro expansion at C:\cygwin\home\Administrator\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.0\REPL\src\REPL.jl:117 [inlined]
 [26] (::getfield(REPL, Symbol("##28#29")){REPL.REPLBackend})() at .\task.jl:259
ERROR: LoadError: MethodError: no method matching cat(::KnetArray{Float32,3}, ::KnetArray{Float32,3})
Closest candidates are:
  cat(::Union{Number, AbstractArray, KnetArray}, ::Union{Number, AbstractArray, KnetArray}...; dims) at C:\Users\dengl\.julia\packages\Knet\HwZrA\src\karray.jl:326
  cat(::Union{Number, Value, AbstractArray, KnetArray}...; dims) at C:\Users\dengl\.julia\packages\Knet\HwZrA\src\karray.jl:229
  cat(::Any...; dims) at abstractarray.jl:1480
Stacktrace:
 [1] #differentiate#3(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::Function) at C:\Users\dengl\.julia\packages\AutoGrad\FKOf4\src\core.jl:148
 [2] differentiate(::Function) at C:\Users\dengl\.julia\packages\AutoGrad\FKOf4\src\core.jl:135
 [3] iterate(::Knet.Minimize{DataSlicer{KnetArray{Float32,3},KnetArray{Float32,N} where N}}) at C:\Users\dengl\.julia\packages\Knet\HwZrA\src\train.jl:24
 [4] #train_rnn!#34(::Int64, ::Int64, ::Int64, ::String, ::Function, ::Rnn_data, ::Chain2{Tuple{Knet.RNN{KnetArray{Float32,3}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}}}}) at C:\Users\dengl\Dropbox\Uni\Projekte\2019_RNN_Segway\01_Code\01_Supervised\src\Functions\train.jl:21
 [5] (::getfield(Main, Symbol("#kw##train_rnn!")))(::NamedTuple{(:max_epochs, :Nvalid, :Nsave, :filename),Tuple{Int64,Int64,Int64,String}}, ::typeof(train_rnn!), ::Rnn_data, ::Chain2{Tuple{Knet.RNN{KnetArray{Float32,3}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}},Dense{Param{KnetArray{Float32,2}},Param{KnetArray{Float32,1}}}}}) at .\none:0
 [6] top-level scope at util.jl:156
 [7] include at .\boot.jl:317 [inlined]
 [8] include_relative(::Module, ::String) at .\loading.jl:1044
 [9] include(::Module, ::String) at .\sysimg.jl:29
 [10] include(::String) at .\client.jl:392
 [11] top-level scope at none:0
in expression starting at C:\Users\dengl\Dropbox\Uni\Projekte\2019_RNN_Segway\01_Code\01_Supervised\src\03_Train_RNN.jl:45
ilkerkesen commented 5 years ago

@denizyuret I use this kind of cat heavily. I'll implement it by using the reshape trick for dims=1 and dims=ndims(x) where dims(x) > 2. @denglerchr by the way, the problem is, in the backward pass, it searches for cat(::KnetArray{Float32,3}, ::KnetArray{Float32,3}) but it is not able to find it since such a method does not exist (it needs dims keyword argument).
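
A sketch of that reshape-based implementation (the name catfirstlast is hypothetical, and it deliberately covers only the first and last dimensions):

# Sketch: concatenate along the first or last dimension via 2-D reshapes.
function catfirstlast(A, B; dims)
    n = ndims(A)
    if dims == 1
        A2 = reshape(A, size(A, 1), div(length(A), size(A, 1)))
        B2 = reshape(B, size(B, 1), div(length(B), size(B, 1)))
        reshape(vcat(A2, B2), size(A, 1)+size(B, 1), size(A)[2:end]...)
    elseif dims == n
        A2 = reshape(A, div(length(A), size(A, n)), size(A, n))
        B2 = reshape(B, div(length(B), size(B, n)), size(B, n))
        reshape(hcat(A2, B2), size(A)[1:end-1]..., size(A, n)+size(B, n))
    else
        error("this reshape trick only covers dims=1 and dims=ndims(A)")
    end
end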

denglerchr commented 5 years ago

Thank you very much. Yes, I figured that this method is missing. However, I was wondering about the message, because a cat method without a specified dimension (the one it was looking for) does not make much sense to me. But I'm glad you want to take a look at it :)

denizyuret commented 5 years ago

Just wanted to add a general comment:

We have several techniques for implementing GPU array operations that are missing from Knet:

Using these techniques (especially the last two), it should be easy to cover all missing array ops as a first step. As a second step, we should test the performance of these options on common array sizes and decide on the best solution to implement.
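
For illustration, the simplest fallback of this kind is a CPU round trip: copy the inputs to the host, run Base's implementation, and copy the result back (a sketch with a hypothetical name; forward pass only, and it pays two host/device transfers):

# Sketch: run a missing op on the CPU and move the result back to the GPU.
cat_cpu_fallback(xs::KnetArray...; dims) = KnetArray(cat(map(Array, xs)...; dims=dims))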

denizyuret commented 5 years ago

CuArrays fallbacks fix this in https://github.com/denizyuret/Knet.jl/commit/85c4125fa64f6407b2f4a78e4c8655083f059af1
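
A quick check on a Knet version that includes that commit (a sketch; assumes a working CUDA device):

using Knet
A = KnetArray(rand(Float32, 3, 4, 5))
B = KnetArray(rand(Float32, 3, 4, 5))
@assert size(vcat(A, B)) == (6, 4, 5)
@assert size(cat(A, B, dims = 3)) == (3, 4, 10)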