FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Batchnorm's variance becomes NaN when minibatch contains just one sample #1992

Open pevnak opened 2 years ago

pevnak commented 2 years ago

Hi All,

I have found a bug in `BatchNorm` when the minibatch contains just one sample. MWE:

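```julia
using Flux
bn = BatchNorm(1)
x = zeros(1, 1)                               # one feature, a batch of a single sample
gradient(() -> sum(bn(x)), Flux.params(bn))   # inside `gradient` BatchNorm runs in training mode and updates its running stats
bn.σ²                                         # the running variance is now NaN
```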
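The MWE can be reduced to plain arithmetic. Below is a minimal, Flux-free sketch of the running-variance update, assuming Float32 statistics, a momentum of `0.1f0` and a running variance initialised to 1 (illustrative values). `step_var` is a hypothetical helper that mirrors the variance line of `_track_stats!` quoted below. It shows that one single-sample batch turns the running variance into NaN, and that no number of ordinary batches afterwards can repair it.

```julia
# Minimal, Flux-free sketch of BatchNorm's running-variance update.
# Assumptions (illustrative): Float32 statistics, momentum 0.1f0,
# running variance initialised to 1.
V = Float32
mtm = 0.1f0
res_mtm = one(V) - mtm

# One update step, mirroring the variance line of `_track_stats!`
# (`step_var` is a hypothetical helper name, not part of Flux).
step_var(running, σ²new, m) = res_mtm * running + mtm * (m / (m - one(V))) * σ²new

running = one(V)

# A single-sample batch: its biased variance is exactly 0,
# and m / (m - 1) = 1 / 0f0 = Inf32, so Inf32 * 0f0 gives NaN32.
running = step_var(running, zero(V), 1)
println("after one m = 1 batch: ", running)

# Later, perfectly ordinary batches cannot repair it: NaN propagates
# through the (1 - momentum) * running term.
for _ in 1:100
    global running = step_var(running, 0.5f0, 32)
end
println("after 100 more batches: ", running)
```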

The problem is caused by using the unbiased estimate of the variance in the function `_track_stats!`:


```julia
function _track_stats!(
  bn, x::AbstractArray{T, N}, μ, σ², reduce_dims,
) where {T, N}
  V = eltype(bn.σ²)
  mtm = bn.momentum
  res_mtm = one(V) - mtm
  m = prod(size(x, i) for i in reduce_dims)   # number of elements reduced per channel

  μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
  σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))

  bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
  # m / (m - 1) is Bessel's correction; for m == 1 it is 1 / 0 == Inf, and Inf * 0 == NaN
  bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
  return nothing
end
```
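Written out for a single channel, the variance line of `_track_stats!` is the update below. Here $\rho$ stands for `bn.momentum`, $m$ for the number of reduced elements, and $s^2$ for the biased batch variance `σ²new`; these symbols are labels chosen for this sketch, not Flux's notation.

```latex
\begin{aligned}
\bar{x} &= \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
s^2 = \frac{1}{m}\sum_{i=1}^{m} \left(x_i - \bar{x}\right)^2, \\
\hat{\sigma}^2_{\text{new}} &= (1-\rho)\,\hat{\sigma}^2_{\text{old}} + \rho\,\frac{m}{m-1}\,s^2 .
\end{aligned}
```

For $m = 1$ the batch mean equals the single sample, so $s^2 = 0$, while $m/(m-1) = 1/0 = \infty$. IEEE 754 arithmetic defines $\infty \cdot 0 = \text{NaN}$, and that NaN is then carried into every later update through the $(1-\rho)\,\hat{\sigma}^2_{\text{old}}$ term.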

To fix it, it is sufficient to change the variance update to:

```julia
bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / max(1, m - one(V))) .* σ²new
```
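A quick way to check the proposed change is to compare the correction factor before and after it. The sketch below assumes Float32 statistics and a momentum of `0.1f0`; `original_factor`, `fixed_factor` and `var_update` are hypothetical helper names, not Flux functions. It shows that the two factors agree for every `m >= 2`, so only the single-sample case changes, and that this case now yields a finite running variance.

```julia
# Compare the Bessel correction factor before and after the proposed fix.
# `original_factor`, `fixed_factor` and `var_update` are hypothetical helper names.
const V = Float32

original_factor(m) = m / (m - one(V))        # 1 / 0f0 == Inf32 when m == 1
fixed_factor(m)    = m / max(1, m - one(V))  # 1 / 1f0 == 1f0 when m == 1

# One running-variance step, mirroring the variance line of `_track_stats!`
# (momentum 0.1f0 is an illustrative assumption).
var_update(running, σ²new, m; mtm = 0.1f0, factor = fixed_factor) =
    (one(V) - mtm) * running + mtm * factor(m) * σ²new

# For m >= 2 the two factors agree, so only the single-sample case changes.
for m in (1, 2, 4, 16)
    println("m = ", m, ": original ", original_factor(m), ", fixed ", fixed_factor(m))
end

before = var_update(1f0, 0f0, 1; factor = original_factor)  # NaN
after  = var_update(1f0, 0f0, 1)                             # finite, about 0.9
println("before fix: ", before, ", after fix: ", after)
```

With the fix, a single-sample batch still blends its zero variance into the running estimate, which shrinks it by a factor of (1 − momentum). Skipping the variance update entirely when `m == 1` would be another possible design, but it is not what this issue proposes.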

I am happy to prepare a proper PR, if it helps.

ToucheSir commented 2 years ago

That would be great, thanks!