This other approach, with `preserve_ps_st = true`, seems to work. Assuming the code above has been executed already:
model_pp = transform(
    ResNet(18; pretrain = true).layers;
    preserve_ps_st = true
);
> ┌ Warning: Preserving the state of `Flux.BatchNorm` is currently not supported. Ignoring the state.
> └ @ LuxFluxTransformExt ~/.julia/packages/Lux/1Iulg/ext/LuxFluxTransformExt.jl:269
ps_pp, st_pp = Lux.setup(Random.default_rng(), model_pp);
@show keys(ps_pp)
> keys(ps_pp) = (:layer_1, :layer_2)
@assert ps_pp.layer_1.layer_1.layer_1.weight == ps_prime.layer_1.weight
@assert ps_pp.layer_1.layer_1.layer_2.scale == ps_prime.layer_2.scale
model(x, ps_pp, st_pp) # this works
Here I only checked that two of the pretrained parameter arrays are equal. I am also unsure what effect ignoring the state has when loading the model with the pretrained parameters.
I see, that is the problem. The initial weights were imported with Lux 0.4, and since some defaults have changed since then, that led to this breakage.
> Here I only checked that two of the pretrained parameter arrays are equal. I am also unsure what effect ignoring the state has when loading the model with the pretrained parameters.
States not being preserved means that your predictions won't be correct. Specify `force_preserve` (https://lux.csail.mit.edu/dev/api/Lux/flux_to_lux#Lux.transform) and that should do it for now.
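For reference, a minimal sketch of what that call might look like, assuming the same packages as in the snippets above are loaded; names like model_fp are only used here for illustration, and the exact semantics of `force_preserve` are described in the linked docs:
model_fp = transform(
    ResNet(18; pretrain = true).layers;
    preserve_ps_st = true,
    force_preserve = true  # attempt to also carry over the BatchNorm state
);
ps_fp, st_fp = Lux.setup(Random.default_rng(), model_fp);
model_fp(x, ps_fp, st_fp)  # `x` as defined earlier in the thread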
ResNets are transformed into Lux from Metalhead using the `resnet` function. `transform` is yielding a `Chain` with two `Chain`s in it, each containing a number of layers. We can also see this if we use `Lux.setup` on the model. There is the option to pass `pretrained = true` to the `resnet` function. However, the pretrained parameters loaded by `_initialize_model` are a "flattened" named tuple of 14 layers. Therefore, the model architecture and the pretrained parameters are not compatible.