Open pevnak opened 2 years ago
Hi All,
I have found a bug in BatchNorm when the minibatch contains just one sample. MWE:
    using Flux
    bn = BatchNorm(1)
    x = zeros(1,1)
    gradient(() -> sum(bn(x)), Flux.params(bn))
    bn.σ²
The problem is caused by using the unbiased estimate of the variance in the function _track_stats!. With a single sample, m = 1, so the correction factor m / (m - 1) divides by zero.
    function _track_stats!(
        bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
    ) where {T, N}
        V = eltype(bn.σ²)
        mtm = bn.momentum
        res_mtm = one(V) - mtm
        # m is the number of samples each statistic is computed over
        m = prod(size(x, i) for i in reduce_dims)
        μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
        σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
        bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
        # m / (m - 1) is the unbiased-variance correction; it divides by zero when m == 1
        bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
        return nothing
    end
To fix it, it is sufficient to change the variance update to
    bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (max(1, m - one(V)))) .* σ²new
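The arithmetic can be checked in isolation with a minimal sketch that does not depend on Flux. The helper name `update_var` and its keyword arguments are made up for illustration; only the update formula itself is taken from the code above.

```julia
# Standalone sketch of the running-variance update (no Flux required).
# `update_var` is a hypothetical helper, not part of Flux.
function update_var(σ²old, σ²batch, m; momentum = 0.1f0, clamped = false)
    V = typeof(σ²old)
    # Denominator of the unbiased-variance correction, optionally clamped to >= 1
    denom = clamped ? max(1, m - one(V)) : m - one(V)
    return (one(V) - momentum) * σ²old + momentum * (m / denom) * σ²batch
end

# A single sample has zero batch variance, so m = 1 yields Inf * 0 = NaN.
before = update_var(1f0, 0f0, 1)

# With the clamp, the factor becomes 1 / 1 and the update stays finite.
after = update_var(1f0, 0f0, 1; clamped = true)

# For batches with more than one sample, the clamp changes nothing.
same = update_var(1f0, 2f0, 4) == update_var(1f0, 2f0, 4; clamped = true)
```

Here `before` is NaN while `after` is finite, and `same` is true, so the clamp only affects the single-sample case.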
I am happy to prepare a proper PR, if it helps.
That would be great, thanks!