LuxDL / Lux.jl

Elegant & Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/
MIT License

Extracting part of a model, with the corresponding parameters and states #617

Closed: Sleort closed this issue 2 months ago

Sleort commented 2 months ago

Is there an easy / "canonical" / "non-hackish" way of extracting (and combining?) parts of a model and its parameters and states in Lux? If not, maybe one should be added?

For example, let's say I'd like to make a simple autoencoder, train it on some data, and then extract the encoder part for some downstream task. I could try to do it like this:

julia> using Lux, Random

julia> rng = Random.default_rng()
TaskLocalRNG()

julia> encoder = Chain(Dense(3=>2), Dense(2=>1));

julia> decoder = Chain(Dense(1=>2), Dense(2=>3));

julia> autoencoder = Chain(; encoder, decoder)
Chain(
    encoder = Chain(
        layer_1 = Dense(3 => 2),        # 8 parameters
        layer_2 = Dense(2 => 1),        # 3 parameters
    ),
    decoder = Chain(
        layer_1 = Dense(1 => 2),        # 4 parameters
        layer_2 = Dense(2 => 3),        # 9 parameters
    ),
)         # Total: 24 parameters,
          #        plus 0 states.

julia> ps, st = Lux.setup(rng, autoencoder);

Next I would train autoencoder and then extract the encoder part. But while

julia> ps.encoder, st.encoder; #works

works just fine, it is not possible to access the encoder directly from autoencoder using getproperty:

julia> autoencoder.encoder; #fails
ERROR: type Chain has no field encoder
Stacktrace:
 [1] getproperty(x::Chain{@NamedTuple{encoder::Chain{@NamedTuple{…}, Nothing}, decoder::Chain{@NamedTuple{…}, Nothing}}, Nothing}, f::Symbol)
...

(Sure, in this example I do have access to encoder, but let's assume I don't / haven't passed it around...)

What does work is integer indexing,

julia> ps[1], st[1]; #works

julia> autoencoder[1]; #works

but that feels less elegant, and it can become cumbersome and error-prone in more complex settings. Furthermore,

julia> first(ps), first(st); #works

julia> first(autoencoder) #fails
ERROR: MethodError: no method matching iterate(::Chain{@NamedTuple{encoder::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}}, Nothing}, decoder::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}}, Nothing}}, Nothing})
...

and if I had made autoencoder as a "flattened" Chain model instead,

julia> autoencoder = Chain(encoder, decoder)
Chain(
    layer_1 = Dense(3 => 2),            # 8 parameters
    layer_2 = Dense(2 => 1),            # 3 parameters
    layer_3 = Dense(1 => 2),            # 4 parameters
    layer_4 = Dense(2 => 3),            # 9 parameters
)         # Total: 24 parameters,
          #        plus 0 states.

julia> ps, st = Lux.setup(rng, autoencoder);

then indexing with a range does not work for ps and st:

julia> ps[1:2], st[1:2]
ERROR: MethodError: no method matching getindex(::@NamedTuple{…}, ::UnitRange{…})

although

julia> autoencoder[1:2];

works.
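As a stopgap for the failing `ps[1:2]`, a range can be applied to the keys of a NamedTuple rather than to the NamedTuple itself. The following is a minimal sketch in plain Julia, not something Lux provides: `subtuple` is a hypothetical helper name, and the `ps` below is a stand-in with placeholder values instead of real Lux parameters.

```julia
# Stand-in for the parameter NamedTuple of the flattened Chain.
# The values are placeholders (assumption), not real Lux parameters.
ps = (layer_1 = [1.0], layer_2 = [2.0], layer_3 = [3.0], layer_4 = [4.0])

# Hypothetical helper: select entries by position while keeping their names.
# `keys(nt)[r]` yields a tuple of Symbols, which `NamedTuple{names}(nt)`
# uses to pick the matching fields out of `nt`.
subtuple(nt::NamedTuple, r::AbstractUnitRange) = NamedTuple{keys(nt)[r]}(nt)

encoder_ps = subtuple(ps, 1:2)   # holds only layer_1 and layer_2
```

The same helper would apply to `st`. Whether the resulting names line up with those of `autoencoder[1:2]` is not verified here, so treat this as a workaround rather than a supported pattern.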


The above was done in Julia v1.10.3 and Lux v0.5.41.

avik-pal commented 2 months ago

I agree this should be added. I personally prefer the .encoder syntax. We can do it by overloading getproperty in https://github.com/LuxDL/Lux.jl/blob/d94f947096e2062ba38004b836ef2c06bf64d504/src/layers/containers.jl#L504-L511.

The logic would be:

  1. If name is a field of the model, we just return that field. So something like model.layers always returns the named tuple.
  2. If name is not a field of the model, look it up in the NamedTuple that stores the layers and return the match.
  3. If neither matches, throw an error.
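The three steps above could look roughly like the sketch below. It uses a hypothetical `ToyChain` struct as a stand-in, since the actual definition of Lux's `Chain` is not shown in this thread; the only assumption carried over is that the sublayers live in a NamedTuple field called `layers`.

```julia
# Hypothetical stand-in for a container layer; this is NOT Lux's Chain.
struct ToyChain{L<:NamedTuple}
    layers::L
end

function Base.getproperty(c::ToyChain, name::Symbol)
    # 1. A real field of the struct (e.g. `layers`) is returned as-is.
    hasfield(typeof(c), name) && return getfield(c, name)
    # 2. Otherwise look the name up in the NamedTuple of sublayers.
    layers = getfield(c, :layers)
    haskey(layers, name) && return getfield(layers, name)
    # 3. Neither matched: raise an error.
    throw(ArgumentError("$(nameof(typeof(c))) has no field or layer `$name`"))
end

# Strings stand in for the encoder/decoder sublayers.
toy = ToyChain((; encoder = "enc", decoder = "dec"))
toy.encoder   # looked up in the NamedTuple (step 2)
toy.layers    # real field, returns the whole NamedTuple (step 1)
```

Checking real fields first means a sublayer that happens to be named `layers` would be shadowed by the field; that trade-off follows directly from the ordering of steps 1 and 2.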