denizyuret / Knet.jl

Koç University deep learning framework.
https://denizyuret.github.io/Knet.jl/latest
Other
1.43k stars 230 forks source link

Derivative of a Function That Includes @diff Macro #670

Open BariscanBozkurt opened 2 years ago

BariscanBozkurt commented 2 years ago

Hello.

I am currently using Julia Version 1.6.3 on a platform with "OS: Linux (x86_64-pc-linux-gnu) CPU: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz, GPU : CuDevice(0): Tesla T4". I am trying to implement a variational autoencoder called Gradient Origin Networks (GONs). GONs are introduced as a generative model that requires neither encoders nor hypernetworks. Consider a variational GON model F. First, a zero vector z_0 is passed through the model F, and then the latent vector is initialized as the negative gradient of the loss with respect to this zero vector. Therefore, the latent vector is determined by a single gradient step. Let us call this latent vector z. Then, the network parameters are optimized using the loss of the reconstruction F(z).

I am currently performing my experiments on MNIST dataset where I linearly interpolated the images to the size of 32x32. The decoding and reparametrization functions are as follows. theta is a vector of model weights.

"""
    reparametrize(mu, logvar)

Reparameterization trick: return `mu .+ epsilon .* exp.(logvar ./ 2)`, where
`epsilon` is standard-normal noise drawn elementwise with the shape of `mu`.
The sample stays differentiable with respect to `mu` and `logvar`.

The noise is drawn as element type `F` and converted to the array type `Atype`.
Both are globals defined outside this snippet.
"""
function reparametrize(mu, logvar)
    # Use `./ 2` rather than `0.5 .*`. The Float64 literal would promote Float32
    # CPU arrays (and hence the whole decoder output) to Float64.
    std = exp.(logvar ./ 2)
    epsilon = convert(Atype, randn(F, size(mu)))
    z = mu .+ epsilon .* std

    return z
end

"""
    decode(theta, z; batch_size = size(z, 2), training = true,
           moments = (bnmoments(), bnmoments(), bnmoments()))

Decode latent codes `z` with the weight vector `theta` and return
`(x_hat, mu, logvar)`.

Steps:
1. `z` is mapped to a mean `mu` and log-variance `logvar` (`theta[1:4]`).
2. A sample is drawn with `reparametrize`.
3. The sample goes through three deconvolution + batchnorm + ELU stages
   (`theta[5:13]`).
4. A final deconvolution + sigmoid produces the output (`theta[14:15]`).

# Keywords
- `batch_size`: number of columns of `z`. Defaults to `size(z, 2)` rather than a
  fixed 64, so batches of any size (e.g. a short last batch) can be decoded.
- `training`: forwarded to every `batchnorm` call.
- `moments`: the three `bnmoments` objects used by the batchnorm layers. The
  default creates fresh objects on every call, so running statistics are not
  kept between calls. Pass persistent objects to accumulate them, e.g. for later
  use with `training = false`.
"""
function decode(theta, z; batch_size = size(z, 2), training = true,
                moments = (bnmoments(), bnmoments(), bnmoments()))
    mu = theta[1] * z .+ theta[2]
    logvar = theta[3] * z .+ theta[4]

    h = reparametrize(mu, logvar)

    # View each latent sample as a 1x1 image with one channel per latent dimension.
    # The channel count comes from the sample itself instead of the global `nz`.
    h = reshape(h, (1, 1, size(h, 1), batch_size))
    h = deconv4(theta[5], h, mode = 1) .+ theta[6]
    h = batchnorm(h, moments[1], theta[7]; training = training)
    h = Knet.elu.(h)

    h = deconv4(theta[8], h, stride = 2, padding = 1, mode = 1) .+ theta[9]
    h = batchnorm(h, moments[2], theta[10]; training = training)
    h = Knet.elu.(h)

    h = deconv4(theta[11], h, stride = 2, padding = 1, mode = 1) .+ theta[12]
    h = batchnorm(h, moments[3], theta[13]; training = training)
    h = Knet.elu.(h)

    h = deconv4(theta[14], h, stride = 2, padding = 1, mode = 1) .+ theta[15]
    x_hat = Knet.sigm.(h)

    return x_hat, mu, logvar
end

The loss is the sum of binary cross-entropy and KL divergence. The code is as follows:

"""
    BCE(x_tensor, x_hat_tensor)

Binary cross-entropy between targets `x_tensor` and predictions `x_hat_tensor`.

Both inputs are first reshaped with Knet's `mat`. The elementwise terms are then
summed along dimension 1 and averaged over what remains. `F(1e-10)` is added
inside each `log` to avoid `log(0)`.
"""
function BCE(x_tensor, x_hat_tensor)
    target = mat(x_tensor)
    pred = mat(x_hat_tensor)
    tiny = F(1e-10)
    hit_term = target .* log.(pred .+ tiny)
    miss_term = (1 .- target) .* log.(1 .- pred .+ tiny)
    per_column = sum(hit_term + miss_term, dims = 1)
    return -mean(per_column)
end

"""
    KLD(mu, logvar)

KL divergence of the diagonal Gaussian `N(mu, exp.(logvar))` from `N(0, I)`:

    -1/2 * mean(sum(1 .+ logvar .- mu.^2 .- exp.(logvar); dims = 1))

Terms are summed along dimension 1 and averaged over the remaining entries.
The result keeps the element type of the inputs (e.g. Float32 stays Float32).
"""
function KLD(mu, logvar)
    per_column = sum(1 .+ logvar .- mu .* mu .- exp.(logvar); dims = 1)
    # Divide by 2 instead of multiplying by the Float64 literal -0.5. The literal
    # would promote a Float32 loss to Float64.
    return -mean(per_column) / 2
end

"""
    loss(theta, x, z)

Decode the latent codes `z` with weights `theta` and return the VAE objective.
The objective is the reconstruction BCE against `x` plus the KL term of the
predicted `mu`/`logvar`.
"""
function loss(theta, x, z)
    reconstruction, mu, logvar = decode(theta, z)
    return BCE(x, reconstruction) + KLD(mu, logvar)
end

Since GON has two steps — (1) use the gradient with respect to the origin to determine the latent vector z, and (2) use that latent vector for reconstruction — I need to track all gradients with respect to the model weights from both steps (1) and (2). Therefore, I wrote the following decoding and loss functions for training.

"""
    decode_train(theta, x; batch_size = size(x, ndims(x)), training = true,
                 higher_order = false)

GON training forward pass. Return `(x_hat, mu, logvar)` from decoding the latent
code reached by one gradient step from the origin.

1. Evaluate `loss` at a zero latent `origin` (`nz × batch_size`) and take its
   gradient with respect to `origin`.
2. Use `z = -gradient` as the latent code and decode it with `theta`.

# Keywords
- `batch_size`: number of samples. Defaults to the last dimension of `x`, which
  is the batch axis for the 4-D image batches used here.
- `training`: forwarded to `decode` in step 2. Step 1 goes through `loss`, which
  calls `decode` with its own defaults.
- `higher_order`:
  - `false` (default, original behavior): step 1 runs on `value.(theta)` and `z`
    is detached with `value`, so no gradient flows from step 1 into `theta`.
  - `true`: `theta` and the gradient stay on the AutoGrad tape, so an enclosing
    `@diff` also differentiates through step 1. This is the same pattern as
    `z = -grad(d, z0)` without `value`.

NOTE(review): per the stack trace in this issue, Knet's affine `batchnorm` backward
(`batchnorm4_back`) raises a `MethodError` when step 1 runs inside an outer `@diff`.
"""
function decode_train(theta, x; batch_size = size(x, ndims(x)), training = true,
                      higher_order = false)
    origin = param(Atype(zeros(nz, batch_size)))

    inner_theta = higher_order ? theta : value.(theta)
    derivative_origin = @diff loss(inner_theta, x, origin)
    dz = grad(derivative_origin, origin)

    z = higher_order ? -dz : -value(dz)

    # Decode the gradient-derived latent `z`. The previous code decoded `origin`
    # here, which silently discarded `z`.
    return decode(theta, z; batch_size = batch_size, training = training)
end

"""
    loss_train(theta, x)

Training objective for the GON. Run `decode_train` on the batch `x` and return
the reconstruction BCE against `x` plus the KL term of the predicted
`mu`/`logvar`.
"""
function loss_train(theta, x)
    reconstruction, mu, logvar = decode_train(theta, x)
    return BCE(x, reconstruction) + KLD(mu, logvar)
end

However, I am not able to take the gradient of the loss_train(theta, x) function. I get the following error when I use the @diff macro of the AutoGrad package. How can I train this model, which requires a second-order derivative (I need the derivative of the function decode_train)? To reproduce this result, you can run the following notebook: https://github.com/BariscanBozkurt/Gradient-Origin-Networks/blob/main/GON_Implementation_Issue.ipynb My code: @diff loss_train(theta, x) The error is:

Stacktrace: [1] copyto!(a::KnetArray{Float32, 4}, b::Base.Broadcast.Broadcasted{Base.Broadcast.Style{AutoGrad.Value}, NTuple{4, Base.OneTo{Int64}}, typeof(identity), Tuple{AutoGrad.Result{KnetArray{Float32, 4}}}}) @ Knet.KnetArrays ~/.julia/packages/Knet/RCkV0/src/knetarrays/broadcast.jl:35 [2] copyto!(x::AutoGrad.Result{KnetArray{Float32, 4}}, y::Base.Broadcast.Broadcasted{Base.Broadcast.Style{AutoGrad.Value}, NTuple{4, Base.OneTo{Int64}}, typeof(identity), Tuple{AutoGrad.Result{KnetArray{Float32, 4}}}}) @ AutoGrad ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:55 [3] materialize! @ ./broadcast.jl:894 [inlined] [4] materialize! @ ./broadcast.jl:891 [inlined] [5] materialize!(dest::AutoGrad.Result{KnetArray{Float32, 4}}, x::AutoGrad.Result{KnetArray{Float32, 4}}) @ Base.Broadcast ./broadcast.jl:887 [6] batchnorm4_back(g::KnetArray{Float32, 4}, x::AutoGrad.Result{KnetArray{Float32, 4}}, dy::AutoGrad.Result{KnetArray{Float32, 4}}; eps::Float64, training::Bool, cache::Knet.Ops20.BNCache, moments::Knet.Ops20.BNMoments, o::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) @ Knet.Ops20 ~/.julia/packages/Knet/RCkV0/src/ops20/batchnorm.jl:262 [7] #batchnorm4x#191 @ ~/.julia/packages/Knet/RCkV0/src/ops20/batchnorm.jl:317 [inlined] [8] #back#210 @ ./none:0 [inlined] [9] differentiate(::Function; o::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) @ AutoGrad ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:165 [10] differentiate @ ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:135 [inlined] [11] decode_train(theta::Vector{Any}, x::KnetArray{Float32, 4}; batch_size::Int64, training::Bool) @ Main ./In[14]:4 [12] decode_train @ ./In[14]:2 [inlined] [13] loss_train(theta::Vector{Any}, x::KnetArray{Float32, 4}) @ Main ./In[16]:2 [14] (::var"#16#17")() @ Main ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:205 [15] differentiate(::Function; o::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) @ AutoGrad 
~/.julia/packages/AutoGrad/TTpeo/src/core.jl:144 [16] differentiate(::Function) @ AutoGrad ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:135 [17] top-level scope @ In[18]:1 [18] eval @ ./boot.jl:360 [inlined] [19] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String) @ Base ./loading.jl:1116 [20] softscope_include_string(m::Module, code::String, filename::String) @ SoftGlobalScope ~/.julia/packages/SoftGlobalScope/u4UzH/src/SoftGlobalScope.jl:65 [21] execute_request(socket::ZMQ.Socket, msg::IJulia.Msg) @ IJulia ~/.julia/packages/IJulia/e8kqU/src/execute_request.jl:67 [22] #invokelatest#2 @ ./essentials.jl:708 [inlined] [23] invokelatest @ ./essentials.jl:706 [inlined] [24] eventloop(socket::ZMQ.Socket) @ IJulia ~/.julia/packages/IJulia/e8kqU/src/eventloop.jl:8 [25] (::IJulia.var"#15#18")() @ IJulia ./task.jl:411 MethodError: no method matching copyto!(::KnetArray{Float32, 4}, ::AutoGrad.Result{KnetArray{Float32, 4}}) Closest candidates are: copyto!(::KnetArray{T, N} where N, ::Array{T, N} where N) where T at /kuacc/users/bbozkurt15/.julia/packages/Knet/RCkV0/src/knetarrays/copy.jl:10 copyto!(::KnetArray{T, N} where N, ::Array{S, N} where N) where {T, S} at /kuacc/users/bbozkurt15/.julia/packages/Knet/RCkV0/src/knetarrays/copy.jl:18 copyto!(::KnetArray{T, N} where N, ::KnetArray{T, N} where N) where T at /kuacc/users/bbozkurt15/.julia/packages/Knet/RCkV0/src/knetarrays/copy.jl:9 ...

Stacktrace: [1] differentiate(::Function; o::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}) @ AutoGrad ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:148 [2] differentiate(::Function) @ AutoGrad ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:135 [3] top-level scope @ In[18]:1 [4] eval @ ./boot.jl:360 [inlined] [5] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String) @ Base ./loading.jl:1116

denizyuret commented 2 years ago

Dear Barışcan, could you try the following:

  1. Try your model without the batchnorm operations.
  2. Try your model with Array and/or CuArray instead of KnetArray for array type.
  3. Send me a minimal working example https://en.wikipedia.org/wiki/Minimal_working_example: complete source file or notebook which I can directly run to get the error. Your explanations above are good, but I cannot run the code.

Here is an mwe that I tried with the same logic, but it does not reproduce the error:

using AutoGrad, Random, CUDA, Knet 

# Minimal GON-style example without batch normalization (reported to work).
# Sizes: latent dim, data dim, batch size (BATCH is not used in this variant).
ZDIM=50
XDIM=100
BATCH=10
atype = Array{Float64}

target = atype(randn(XDIM))
w = Param(atype(randn(XDIM,ZDIM)))
# Latent origin; the inner @diff below takes the gradient with respect to it.
z0 = Param(atype(zeros(ZDIM)))
decoder(z) = w*z
qloss(x,y) = sum(abs2, x .- y)

# Two-step GON loss:
#   1. z = -gradient of qloss at the origin z0;
#   2. qloss evaluated at decoder(z).
# `grad(d, z0)` is not wrapped in `value`, so the outer @diff presumably
# differentiates through the inner gradient as well.
function loss(x)
    d = @diff qloss(x, decoder(z0))
    z = -grad(d,z0)
    qloss(x, decoder(z))
end

# Gradient of the two-step loss with respect to the weight matrix w.
J = @diff loss(target)
grad(J, w) |> summary |> println
BariscanBozkurt commented 2 years ago

Hello again,

I appreciate your comments, which were helpful for me to understand several issues.

When I do not use the batchnorm function, I am able to take the derivative of my loss_train(theta, x) function. I tested both Array{Float32} and KnetArray{Float32} as the array type; both work fine if I do not include batch normalization in my model code. However, in my opinion batch normalization is an important component of the model. When I modify your example as follows, I get exactly the same error (so this can serve as a minimal working example):

using AutoGrad, Random, CUDA, Knet 

# Same example with Knet's affine batchnorm after the linear layer.
# On the GPU (KnetArray) this reproduces the MethodError raised in
# `batchnorm4_back` that is reported in this thread.
ZDIM=50
XDIM=100
BATCH=10
# atype = Array{Float64}
atype = (CUDA.functional() ? KnetArray{Float32} : Array{Float32})

target = atype(randn(XDIM,BATCH)) # target now has a batch dimension
w = Param(atype(randn(XDIM,ZDIM)))
bparam = Param(atype((bnparams(XDIM)))) # learnable batchnorm scale/shift parameters
z0 = Param(atype(zeros(ZDIM,BATCH))) # latent origin, one column per sample
# Linear layer followed by affine batch normalization.
# A fresh bnmoments() is created on every call.
decoder(z) = batchnorm(w*z, bnmoments(), bparam; training = true ) # linear layer + affine batchnorm
qloss(x,y) = sum(abs2, x .- y)

# Two-step GON loss; the outer @diff must differentiate through grad(d, z0),
# which includes batchnorm's backward pass.
function loss(x)
    d = @diff qloss(x, decoder(z0))
    z = -grad(d,z0)
    qloss(x, decoder(z))
end

J = @diff loss(target)
grad(J, w) |> summary |> println

I therefore conclude that the error is related to the batch normalization layer. Without batchnorm, the optimization of the model goes fine. However, the final performance is worse than that of the model that uses batch normalization (in PyTorch), so I cannot reproduce exactly the results given in the official GON code (https://github.com/BariscanBozkurt/GON/blob/master/Variational-GON.py). How can I change my code (or the modified example code I provided in this comment) to take the derivative of a loss computed by a model with a batch normalization layer?

BariscanBozkurt commented 2 years ago

Hello again,

I appreciate your comments, which were helpful for me to understand several issues.

When I do not use the batchnorm function, I am able to take the derivative of my loss_train(theta, x) function. I tested both Array{Float32} and KnetArray{Float32} as the array type; both work fine if I do not include batch normalization in my model code. However, in my opinion batch normalization is an important component of the model. When I modify your example as follows, I get exactly the same error (so this can serve as a minimal working example):

using AutoGrad, Random, CUDA, Knet 

# Same example with Knet's affine batchnorm after the linear layer.
# On the GPU (KnetArray) this reproduces the MethodError raised in
# `batchnorm4_back` that is reported in this thread.
ZDIM=50
XDIM=100
BATCH=10
# atype = Array{Float64}
atype = (CUDA.functional() ? KnetArray{Float32} : Array{Float32})

target = atype(randn(XDIM,BATCH)) # target now has a batch dimension
w = Param(atype(randn(XDIM,ZDIM)))
bparam = Param(atype((bnparams(XDIM)))) # learnable batchnorm scale/shift parameters
z0 = Param(atype(zeros(ZDIM,BATCH))) # latent origin, one column per sample
# Linear layer followed by affine batch normalization.
# A fresh bnmoments() is created on every call.
decoder(z) = batchnorm(w*z, bnmoments(), bparam; training = true ) # linear layer + affine batchnorm
qloss(x,y) = sum(abs2, x .- y)

# Two-step GON loss; the outer @diff must differentiate through grad(d, z0),
# which includes batchnorm's backward pass.
function loss(x)
    d = @diff qloss(x, decoder(z0))
    z = -grad(d,z0)
    qloss(x, decoder(z))
end

J = @diff loss(target)
grad(J, w) |> summary |> println

I therefore conclude that the error is related to the batch normalization layer. Without batchnorm, the optimization of the model goes fine. However, the final performance is worse than that of the model that uses batch normalization (in PyTorch), so I cannot reproduce exactly the results given in the official GON code (https://github.com/cwkx/GON). How can I change my code (or the modified example code I provided in this comment) to take the derivative of a loss computed by a model with a batch normalization layer?

In the above minimal working example, I use KnetArray{Float32} as the array type. I then realized that if I use Array{Float32}, as in the following code,

using AutoGrad, Random, CUDA, Knet 

# CPU variant (Array{Float32}) of the affine-batchnorm example.
# Reported to fail in `batchnorm4_back` with
# "Cannot `convert` an object of type AutoGrad.Result{Float32} to ... Float32".
ZDIM=50
XDIM=100
BATCH=10
atype = Array{Float32}
# atype = (CUDA.functional() ? KnetArray{Float32} : Array{Float32})

target = atype(randn(XDIM, BATCH))
w = Param(atype(randn(XDIM,ZDIM)))
bparam = Param(atype((bnparams(XDIM)))) # learnable batchnorm scale/shift parameters
z0 = Param(atype(zeros(ZDIM,BATCH))) # latent origin, one column per sample
decoder(z) = batchnorm(w*z, bnmoments(), bparam; training = true ) # linear layer + affine batchnorm
qloss(x,y) = sum(abs2, x .- y)

# Two-step GON loss; the outer @diff differentiates through grad(d, z0).
function loss(x)
    d = @diff qloss(x, decoder(z0))
    z = -grad(d,z0)
    qloss(x, decoder(z))
end

J = @diff loss(target)
grad(J, w) |> summary |> println

I get the following error (which is very similar to the previous one):

Stacktrace:
  [1] setindex!
    @ ./array.jl:845 [inlined]
  [2] setindex!
    @ ./multidimensional.jl:639 [inlined]
  [3] macro expansion
    @ ./broadcast.jl:984 [inlined]
  [4] macro expansion
    @ ./simdloop.jl:77 [inlined]
  [5] copyto!
    @ ./broadcast.jl:983 [inlined]
  [6] copyto!
    @ ./broadcast.jl:936 [inlined]
  [7] copyto!(x::AutoGrad.Result{Array{Float32, 4}}, y::Base.Broadcast.Broadcasted{Base.Broadcast.Style{AutoGrad.Value}, NTuple{4, Base.OneTo{Int64}}, typeof(identity), Tuple{AutoGrad.Result{Array{Float32, 4}}}})
    @ AutoGrad ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:55
  [8] materialize!
    @ ./broadcast.jl:894 [inlined]
  [9] materialize!
    @ ./broadcast.jl:891 [inlined]
 [10] materialize!(dest::AutoGrad.Result{Array{Float32, 4}}, x::AutoGrad.Result{Array{Float32, 4}})
    @ Base.Broadcast ./broadcast.jl:887
 [11] batchnorm4_back(g::AutoGrad.Result{Array{Float32, 4}}, x::AutoGrad.Result{Array{Float32, 4}}, dy::AutoGrad.Result{Array{Float32, 4}}; eps::Float64, training::Bool, cache::Knet.Ops20.BNCache, moments::Knet.Ops20.BNMoments, o::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Knet.Ops20 ~/.julia/packages/Knet/RCkV0/src/ops20/batchnorm.jl:262
 [12] #batchnorm4g#189
    @ ~/.julia/packages/Knet/RCkV0/src/ops20/batchnorm.jl:296 [inlined]
 [13] #back#196
    @ ./none:0 [inlined]
 [14] differentiate(::Function; o::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ AutoGrad ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:165
 [15] differentiate
    @ ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:135 [inlined]
 [16] loss(x::Vector{Float32})
    @ Main ./In[4]:17
 [17] (::var"#15#16")()
    @ Main ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:205
 [18] differentiate(::Function; o::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ AutoGrad ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:144
 [19] differentiate(::Function)
    @ AutoGrad ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:135
 [20] top-level scope
    @ In[4]:22
 [21] eval
    @ ./boot.jl:360 [inlined]
 [22] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1116
 [23] softscope_include_string(m::Module, code::String, filename::String)
    @ SoftGlobalScope ~/.julia/packages/SoftGlobalScope/u4UzH/src/SoftGlobalScope.jl:65
 [24] execute_request(socket::ZMQ.Socket, msg::IJulia.Msg)
    @ IJulia ~/.julia/packages/IJulia/e8kqU/src/execute_request.jl:67
 [25] #invokelatest#2
    @ ./essentials.jl:708 [inlined]
 [26] invokelatest
    @ ./essentials.jl:706 [inlined]
 [27] eventloop(socket::ZMQ.Socket)
    @ IJulia ~/.julia/packages/IJulia/e8kqU/src/eventloop.jl:8
 [28] (::IJulia.var"#15#18")()
    @ IJulia ./task.jl:411
MethodError: Cannot `convert` an object of type AutoGrad.Result{Float32} to an object of type Float32
Closest candidates are:
  convert(::Type{T}, ::Base.TwicePrecision) where T<:Number at twiceprecision.jl:250
  convert(::Type{T}, ::AbstractChar) where T<:Number at char.jl:180
  convert(::Type{T}, ::CartesianIndex{1}) where T<:Number at multidimensional.jl:136
  ...

Stacktrace:
 [1] differentiate(::Function; o::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ AutoGrad ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:148
 [2] differentiate(::Function)
   @ AutoGrad ~/.julia/packages/AutoGrad/TTpeo/src/core.jl:135
 [3] top-level scope
   @ In[4]:22
 [4] eval
   @ ./boot.jl:360 [inlined]
 [5] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
   @ Base ./loading.jl:1116
BariscanBozkurt commented 2 years ago

I strongly believe that the issue is related to the affine implementation of batchnorm function in Knet. Here, I will report all the observations I made as well as my custom solution although I might be wrong at some points (please correct me).

First of all, in our mwe, if I do not feed bparam (parameter for batch normalization used in affine implementation) inside the batchnorm function, I do not get any error and the code works fine. Therefore, the below code works fine

using AutoGrad, Random, CUDA, Knet 

# Variant using Knet's non-affine batchnorm (no bparam passed): reported to work.
ZDIM=50
XDIM=100
BATCH=10
# atype = Array{Float64}
atype = (CUDA.functional() ? KnetArray{Float32} : Array{Float32})

target = atype(randn(XDIM,BATCH))
w = Param(atype(randn(XDIM,ZDIM)))
bparam = Param(atype((bnparams(XDIM)))) # defined for comparison but unused below
z0 = Param(atype(zeros(ZDIM,BATCH))) # latent origin, one column per sample
# bparam is intentionally NOT passed, so batchnorm only normalizes: (y .- mu) .* ivar
decoder(z) = batchnorm(w*z, bnmoments(); training = true ) # linear layer + non-affine batchnorm
qloss(x,y) = sum(abs2, x .- y)

# Two-step GON loss; the outer @diff differentiates through grad(d, z0).
function loss(x)
    d = @diff qloss(x, decoder(z0))
    z = -grad(d,z0)
    qloss(x, decoder(z))
end

J = @diff loss(target)
grad(J, w) |> summary |> println

Since I do not use bparam, batch normalization function only uses the mu and ivar from the data (check _batchnorm4_fused in https://github.com/denizyuret/Knet.jl/blob/048587010acb3cf4ccea821a675a16d977af0b75/src/ops20/batchnorm.jl#L19-L54). Therefore, batch normalization is performed in the following sense

y .= (y .- mu) .* ivar (Eq. 1)

However, what I want is the following,

y .= g .* (y .- mu) .* ivar .+ b (Eq. 2)

Therefore, I wrote a custom batch normalization function based on the batchnorm of Knet as the following.

# Dimension helpers
# Shape for per-channel parameters of `y`: 1 in every dimension except the
# second-to-last, e.g. (1, 1, C, 1) for 4-D input and (C, 1) for a matrix.
@inline function _wsize(y)
    dims = size(y)
    nlead = ndims(y) - 2
    return ((1 for _ in 1:nlead)..., dims[end-1], 1)
end

"""
    mybatchnorm(x, moments, bparam; training = true)

Affine batch normalization composed from Knet's non-affine `batchnorm`:
`g .* batchnorm(x, moments) .+ b`.

`bparam` packs `[g; b]`: the first half is the scale `g` and the second half is
the shift `b`. Both halves are reshaped with `_wsize(x)` so they broadcast along
the second-to-last dimension of `x`.

Throws `ArgumentError` if `length(bparam)` is odd.
"""
function mybatchnorm(x, moments, bparam; training = true)
    n = length(bparam)
    iseven(n) || throw(ArgumentError("bparam must have even length, got $n"))
    # Integer division: `n / 2` is a Float64, and Float64 ranges are not valid
    # array indices in Julia.
    half = n ÷ 2
    g = reshape(bparam[1:half], _wsize(x))
    b = reshape(bparam[half+1:n], _wsize(x))
    return g .* batchnorm(x, moments; training = training) .+ b
end

In this function, I pass my learnable parameter bparam, which contains the g and b vectors used in Eq. 2, and return the required affine transformation. I believe this computes exactly the same thing as the affine version of Knet's batchnorm function. However, unlike Knet's affine batchnorm, this custom batch normalization function does not produce any error when taking the derivative of the loss function. In conclusion, the following code works:

using AutoGrad, Random, CUDA, Knet 

# Dimension helpers
# Shape for per-channel parameters of `y`: 1 in every dimension except the
# second-to-last, e.g. (1, 1, C, 1) for 4-D input and (C, 1) for a matrix.
@inline function _wsize(y)
    dims = size(y)
    nlead = ndims(y) - 2
    return ((1 for _ in 1:nlead)..., dims[end-1], 1)
end

"""
    mybatchnorm(x, moments, bparam; training = true)

Affine batch normalization composed from Knet's non-affine `batchnorm`:
`g .* batchnorm(x, moments) .+ b`.

`bparam` packs `[g; b]`: the first half is the scale `g` and the second half is
the shift `b`. Both halves are reshaped with `_wsize(x)` so they broadcast along
the second-to-last dimension of `x`.

Throws `ArgumentError` if `length(bparam)` is odd.
"""
function mybatchnorm(x, moments, bparam; training = true)
    n = length(bparam)
    iseven(n) || throw(ArgumentError("bparam must have even length, got $n"))
    # Integer division: `n / 2` is a Float64, and Float64 ranges are not valid
    # array indices in Julia.
    half = n ÷ 2
    g = reshape(bparam[1:half], _wsize(x))
    b = reshape(bparam[half+1:n], _wsize(x))
    return g .* batchnorm(x, moments; training = training) .+ b
end

# Variant using the custom affine `mybatchnorm` (scale/shift applied outside
# Knet's batchnorm): reported to work.
ZDIM=50
XDIM=100
BATCH=10
# atype = Array{Float64}
atype = (CUDA.functional() ? KnetArray{Float32} : Array{Float32})

target = atype(randn(XDIM,BATCH))
w = Param(atype(randn(XDIM,ZDIM)))
bparam = Param(atype((bnparams(XDIM)))) # learnable [g; b] for mybatchnorm
z0 = Param(atype(zeros(ZDIM,BATCH))) # latent origin, one column per sample
decoder(z) = mybatchnorm(w*z, bnmoments(), bparam; training = true ) # linear layer + custom affine batchnorm
qloss(x,y) = sum(abs2, x .- y)

# Two-step GON loss; the outer @diff differentiates through grad(d, z0).
function loss(x)
    d = @diff qloss(x, decoder(z0))
    z = -grad(d,z0)
    qloss(x, decoder(z))
end

J = @diff loss(target)
grad(J, w) |> summary |> println

I could not figure out the reason for the error I get with the affine implementation of Knet's batchnorm function. I hope my observations help us figure it out together. I will keep working on it, and I would appreciate any comments that help explain the root cause of the issue.