denizyuret opened this issue 5 years ago
The missing piece of the interface is that you cannot initialize your parameters/moments as KnetArrays.
Before we move on, do you have a suggestion, @cangumeli?
```julia
julia> H=14; W=14; C=64; N=10;

julia> moments = bnmoments();

julia> params = bnparams(C);

julia> x = KnetArray(randn(Float32,H,W,C,N));

julia> y = batchnorm(x, moments, params);
ERROR: MethodError: no method matching batchnorm4(::Array{Float64,4}, ::Array{Float64,4}, ::KnetArray{Float32,4}; moments=Knet.BNMoments(0.1, nothing, nothing, zeros, ones), training=false, cache=Knet.BNCache(nothing, nothing, nothing, nothing, nothing))
Closest candidates are:
  batchnorm4(::Array{T,N} where N, ::Array{T,N} where N, ::Array{T,N} where N; o...) where T at /kuacc/users/eakyurek13/.julia/packages/Knet/3lzCR/src/batchnorm.jl:364
  batchnorm4(::##1032, ::##1033, ::AutoGrad.Value{##1034}; o...) where {##1032, ##1033, ##1034} at none:0
  batchnorm4(::##1032, ::AutoGrad.Value{##1033}, ::##1034; o...) where {##1032, ##1033, ##1034} at none:0
  ...
Stacktrace:
 [1] #batchnorm#386(::Nothing, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::KnetArray{Float32,4}, ::Knet.BNMoments, ::Array{Float64,1}) at /kuacc/users/eakyurek13/.julia/packages/Knet/3lzCR/src/batchnorm.jl:95
 [2] batchnorm(::KnetArray{Float32,4}, ::Knet.BNMoments, ::Array{Float64,1}) at /kuacc/users/eakyurek13/.julia/packages/Knet/3lzCR/src/batchnorm.jl:79
 [3] top-level scope at none:0
```
Yes, but there's no problem. You just need to do this: `KnetArray(bnparams(dim))`
https://github.com/ilkerkesen/Sloth.jl/blob/master/src/layers.jl#L83
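For reference, here is a minimal sketch of that workaround applied to the failing example above (assuming a GPU is available; passing `Float32` to `bnparams` is my addition, so that the parameter eltype matches `x` rather than the `Float64` default):

```julia
using Knet

H, W, C, N = 14, 14, 64, 10
moments = bnmoments()
params  = KnetArray(bnparams(Float32, C))   # params on the GPU, matching x's eltype
x = KnetArray(randn(Float32, H, W, C, N))
y = batchnorm(x, moments, params)           # dispatches to the KnetArray method, no MethodError
```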
We are looking for a convenient interface like the RNN one in Knet.jl. For example, RNN had a usegpu
option even before the new interface (see this). So our current batchnorm has been inconsistent with that from the very beginning; that is the problem I meant.
However, with the new interface, we probably need something like this (it doesn't break old code):

```julia
mutable struct BatchNorm
    params              # scale and bias, as returned by bnparams
    moments::BNMoments  # running mean and variance
end

# Note: gpu() returns the active device id and -1 when no GPU is
# available, so the check must be >= 0, not > 0.
function BatchNorm(channels::Int; usegpu=gpu()>=0, eltype=Float64, o...)
    w = bnparams(eltype, channels)
    m = bnmoments(; o...)
    BatchNorm(usegpu ? Param(KnetArray(w)) : Param(w), m)
end
```
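To mirror the RNN-style callable interface, the struct could also be made callable. A minimal sketch (this method is my addition, not part of the proposal above):

```julia
# Hypothetical: forward pass via call syntax, mirroring Knet's RNN objects.
(bn::BatchNorm)(x) = batchnorm(x, bn.moments, bn.params)

# Usage sketch:
bn = BatchNorm(64; usegpu=true, eltype=Float32)
x = KnetArray(randn(Float32, 14, 14, 64, 10))
y = bn(x)
```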
I found the ONNX documents useful for interface discussions. We can also look at PyTorch, Keras, TF, etc.
We can also apply the dropout trick to BatchNorm, i.e. enabling different behaviours depending on the @diff context.
Batch normalization needs to use the current batch's mean and variance inside the @diff context and update the running mean and variance; in the testing phase, however, it needs to use the running mean and variance without updating them.
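A minimal sketch of what that trick could look like, refining the callable method sketched above. The assumption here is that the @diff context can be approximated by checking whether the input is wrapped in an AutoGrad.Value, which holds for any layer downstream of a Param during recording; a first layer receiving a raw array would need the keyword override:

```julia
using Knet, AutoGrad

# Sketch: pick the mode from the @diff context the way dropout does.
# Inside @diff, anything computed from a Param is an AutoGrad.Value,
# so `x isa AutoGrad.Value` approximates "we are training"; callers
# can override via the keyword (e.g. for a first layer that receives
# a raw array even during training).
function (bn::BatchNorm)(x; training = x isa AutoGrad.Value)
    batchnorm(x, bn.moments, bn.params; training = training)
end
```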
We should check the current batchnorm interface and, if necessary, update it to use callable objects the way RNNs do.