FluxML / Flux.jl


Weights shape not validated against kernel, channels #2506

Open BioTurboNick opened 3 days ago

BioTurboNick commented 3 days ago
weights = Flux.kaiming_normal()(3, 3, 1)
Conv((3, 3), 1 => 1; pad = (1, 1), init = (_...) -> weights)
# Conv((3,), 3 => 1, pad=1)  # 10 parameters

weights = Flux.kaiming_normal()(3, 3, 1, 1)
Conv((3, 3), 1 => 1; pad = (1, 1), init = (_...) -> weights)
# Conv((3, 3), 1 => 1, pad=1)  # 10 parameters

I wanted to specify the weights exactly for testing, but encountered this odd result: with a 3-dimensional weight array, the constructor silently produced a Conv((3,), 3 => 1) layer instead of the requested Conv((3, 3), 1 => 1). I think there should be validation to ensure that the weight shape matches the kernel size and input channels, with an error on any mismatch.
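For context, a sketch of the shape Flux actually allocates here (assuming Flux is loaded; the (kernel..., ch_in ÷ groups, ch_out) layout is from the Conv documentation):

```julia
using Flux

# Conv((3, 3), 1 => 1) stores its weight as a 4-D array laid out as
# (kernel_width, kernel_height, input_channels ÷ groups, output_channels),
# so the rank-3 array returned by kaiming_normal()(3, 3, 1) cannot match it.
layer = Conv((3, 3), 1 => 1; pad = (1, 1))
size(layer.weight)  # (3, 3, 1, 1)
```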

CarloLucibello commented 3 days ago

Yes, those sizes should definitely be validated.

mcabbott commented 2 days ago

You might be looking for Conv(weights; pad=(1,1))? I.e. there's a method which accepts weights::Array, for exactly this purpose. It does not take (3, 3), 1 => 1 since, as you note, these are implied by the array size.

The methods which accept an init function certainly assume size(init(s...)) == s. Maybe they could all be made to check this somehow, but it does seem a somewhat strange path.

The initialisation of the weight matrix is W = init(out, in), calling the function given to keyword init
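For instance (a sketch of the weight-array method mentioned above; the kernel size and channel counts are inferred from the array's dimensions):

```julia
using Flux

# Build the weight with the full 4-D shape and hand it to Conv directly;
# (3, 3) and 1 => 1 are read off size(weights), so no init closure is needed.
weights = Flux.kaiming_normal()(3, 3, 1, 1)
layer = Conv(weights; pad = (1, 1))
size(layer.weight)  # (3, 3, 1, 1), i.e. equivalent to Conv((3, 3), 1 => 1, pad = 1)
```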

BioTurboNick commented 2 days ago

Ah, that's fair, thanks. Didn't think to look for a different method signature for this.

mcabbott commented 2 days ago

The check, if we do want one, could look like this -- maybe not as messy as I pictured at first:

function _sizecheck(f, sz::Integer...)
  W = f(sz...)
  size(W) == sz || error("bad size! (except more friendly)")
  W
end

function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity;
               init = glorot_uniform, bias = true)
  # Dense(init(out, in), bias, σ)  # current code
  Dense(_sizecheck(init, out, in), bias, σ)  # with new check
end
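For illustration, such a check would catch the closure pattern from the original report (a sketch; _sizecheck is repeated here so the snippet stands alone):

```julia
function _sizecheck(f, sz::Integer...)
  W = f(sz...)
  size(W) == sz || error("init returned size $(size(W)), expected $sz")
  W
end

# An init closure that ignores the requested dims -- the failure mode above:
weights = randn(3, 3, 1)
bad_init = (_...) -> weights

_sizecheck(bad_init, 3, 3, 1)         # passes: sizes agree
try
    _sizecheck(bad_init, 3, 3, 1, 1)  # now errors instead of silently
catch err                             # building the wrong layer
    @assert err isa ErrorException
end
```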

(When you pass an array bias = [1, 2, 3.0] to a layer constructor, I think its size is always checked.)