LuxDL / Boltz.jl

Accelerate your ML research using pre-built Deep Learning Models with Lux
https://luxdl.github.io/Boltz.jl/
MIT License

ResNet18: incompatible architecture and pretrained parameters #18

Status: Closed (closed by andreuvall)

andreuvall commented 1 year ago

ResNets are converted from Metalhead into Lux by the resnet function, which uses transform. transform yields a Chain containing two Chains, each holding a number of layers. This is also visible when calling Lux.setup on the model:

using Metalhead
using Lux
using Random

model = transform(ResNet(18).layers);
ps, st = Lux.setup(Random.default_rng(), model);
@show keys(ps)
> keys(ps) = (:layer_1, :layer_2)

Alternatively, one can pass pretrained = true to the resnet function. However, the pretrained parameters loaded by _initialize_model form a "flattened" NamedTuple of 14 layers, whereas the model returned by transform nests its layers inside two Chains.

using Boltz

_, ps_prime, st_prime = resnet(:resnet18; pretrained = true);
@show keys(ps_prime)
> keys(ps_prime) = (:layer_1, :layer_2, :layer_3, :layer_4, :layer_5, :layer_6, :layer_7, :layer_8, :layer_9, :layer_10, :layer_11, :layer_12, :layer_13, :layer_14)

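Therefore, the model architecture and the pretrained parameters are not compatible: the model expects the two top-level entries layer_1 and layer_2, whereas ps_prime provides fourteen.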

x = randn(Float32, 224, 224, 3, 1);
model(x, ps, st)  # this works
model(x, ps_prime, st_prime)  # but this doesn't
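One conceivable workaround, sketched here under explicit assumptions, is to copy the arrays of the flat ps_prime and st_prime into the nested trees that Lux.setup returns for the current model. The helper names collect_leaves!, refill and restructure_like are hypothetical and are not part of Lux or Boltz. The sketch assumes that a depth-first walk over both trees visits the same arrays in the same order, which is plausible if the 14 flat layers are simply the two nested Chains flattened, but this is not confirmed in the thread. To avoid silently mixing arrays, the helper raises an error on any shape or count mismatch.

```julia
# Hypothetical helpers (not part of Lux or Boltz).

# Collect every array leaf of a nested NamedTuple/Tuple, depth first.
function collect_leaves!(acc::Vector{Any}, x)
    if x isa AbstractArray
        push!(acc, x)
    elseif x isa Union{NamedTuple,Tuple}
        for v in values(x)
            collect_leaves!(acc, v)
        end
    end
    return acc
end

# Rebuild `x`, replacing its array leaves (depth first) with `leaves[pos[]]`.
function refill(x, leaves::Vector{Any}, pos::Base.RefValue{Int})
    if x isa AbstractArray
        pos[] += 1
        pos[] <= length(leaves) || error("source has fewer array leaves than template")
        y = leaves[pos[]]
        size(y) == size(x) ||
            error("shape mismatch at leaf $(pos[]): $(size(x)) vs $(size(y))")
        return y
    elseif x isa NamedTuple
        vals = Any[refill(v, leaves, pos) for v in values(x)]
        return NamedTuple{keys(x)}(Tuple(vals))
    elseif x isa Tuple
        return Tuple(Any[refill(v, leaves, pos) for v in x])
    else
        # Non-array leaves (e.g. a `training = Val(true)` flag) are kept from the template.
        return x
    end
end

# Return `template` (nested) filled with the array leaves of `source` (e.g. flat).
function restructure_like(template, source)
    leaves = collect_leaves!(Any[], source)
    pos = Ref(0)
    out = refill(template, leaves, pos)
    pos[] == length(leaves) ||
        error("template has $(pos[]) array leaves, source has $(length(leaves))")
    return out
end
```

With the variables from above, this would be used as ps_fixed = restructure_like(ps, ps_prime) and st_fixed = restructure_like(st, st_prime), followed by model(x, ps_fixed, st_fixed). An error from the helper means the two trees do not line up as assumed.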
andreuvall commented 1 year ago

Another approach, passing preserve_ps_st = true to transform, seems to work. Assuming the code above has already been executed:

model_pp = transform(
    ResNet(18; pretrain = true).layers; 
    preserve_ps_st = true
);
> ┌ Warning: Preserving the state of `Flux.BatchNorm` is currently not supported. Ignoring the state.
> └ @ LuxFluxTransformExt ~/.julia/packages/Lux/1Iulg/ext/LuxFluxTransformExt.jl:269

ps_pp, st_pp = Lux.setup(Random.default_rng(), model_pp);
@show keys(ps_pp)
> keys(ps_pp) = (:layer_1, :layer_2)

@assert ps_pp.layer_1.layer_1.layer_1.weight == ps_prime.layer_1.weight
@assert ps_pp.layer_1.layer_1.layer_2.scale == ps_prime.layer_2.scale
model(x, ps_pp, st_pp)  # this works

Here I only checked that two of the pretrained parameter arrays are equal. I am also unsure of the effects of the state being ignored when loading the model and the pretrained parameters.
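To check every pretrained array rather than two, one could compare the trees leaf by leaf. The sketch below uses hypothetical helper names (array_leaves and same_leaves, not part of Lux): it walks each nested NamedTuple depth first, collects its arrays, and compares them in order. Like the two spot checks above, it assumes both trees list the same arrays in the same order, so a false result may mean either different values or a different ordering.

```julia
# Hypothetical helpers (not part of Lux or Boltz).

# Collect every array leaf of a nested NamedTuple/Tuple, depth first.
function array_leaves(x, acc::Vector{Any} = Any[])
    if x isa AbstractArray
        push!(acc, x)
    elseif x isa Union{NamedTuple,Tuple}
        for v in values(x)
            array_leaves(v, acc)
        end
    end
    return acc
end

# True if both trees hold equal arrays in the same depth-first order,
# regardless of how they are nested.
function same_leaves(a, b)
    la, lb = array_leaves(a), array_leaves(b)
    length(la) == length(lb) || return false
    return all(x == y for (x, y) in zip(la, lb))
end
```

With the variables from above, @assert same_leaves(ps_pp, ps_prime) would extend the two spot checks to every parameter array.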

avik-pal commented 1 year ago

I see, that is the problem. The pretrained weights were imported with Lux 0.4, and some defaults have changed since then, which led to this breakage.

Quoting andreuvall: "Here I only checked that two of the pretrained parameter arrays are equal. I am also unsure of the effects of the state being ignored when loading the model and the pretrained parameters."

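Not preserving the states means that your predictions won't be correct: the BatchNorm layers would use freshly initialized running statistics instead of the pretrained ones. Pass force_preserve = true to transform together with preserve_ps_st = true (see https://lux.csail.mit.edu/dev/api/Lux/flux_to_lux#Lux.transform), and that should do it for now.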