LuxDL / Lux.jl

Elegant and Performant Scientific Machine Learning in Julia
https://lux.csail.mit.edu/

`getkeypath` and `layer_map` not fully working with model with `Parallel` layers #1068

Open dmetivie opened 5 days ago

dmetivie commented 5 days ago

Trying the example of `layer_map` here, I wonder how to get back a specific layer given a `KeyPath`.
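For reference, the setup in that example is roughly the following (paraphrased from the `layer_map` docstring, so details may differ):

```julia
using Lux, Random

c = Parallel(
    +; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
    dense_3=Dense(5 => 1))

rng = Random.default_rng()
ps, st = Lux.setup(rng, c)

function zero_dense_params(l, ps, st, name)
    if l isa Dense
        println("zeroing params of $name")
        ps = merge(ps, (; weight=zero.(ps.weight), bias=zero.(ps.bias)))
    end
    return l, ps, st
end
```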

In the example, calling it on the parameters `ps`

```julia
getkeypath(ps, KeyPath(:chain, :dense_1))
```

works; however, with the model `c` it does not:

```julia
getkeypath(c, KeyPath(:chain, :dense_1))
# ERROR: type Parallel has no field chain
```

I wondered whether you intended this to work, as it would be very convenient to target a specific layer and get its parameters (with `ps`) or its type (with `c`). Note that doing

```julia
getkeypath(c.layers, KeyPath(:chain, :dense_1))
```

works, and it does on regular layers too. It looks like a dispatch such as `getkeypath(c::Lux.Parallel, kp) = getkeypath(c.layers, kp)` could do the job (however, it does not work directly).
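For what it's worth, here is a standalone sketch of the idea, under the assumption that the wrapped field name is carried in the `AbstractLuxWrapperLayer` type parameter (the helper names `unwrap`, `descend`, and `getlayerpath` are made up, not Lux API):

```julia
using Lux
using Functors: KeyPath

# Walk a KeyPath into the model itself, transparently unwrapping wrapper
# layers (Chain, Parallel, ...). Note: `kp.keys` pokes at KeyPath's internal
# field.
unwrap(l::Lux.AbstractLuxWrapperLayer{field}) where {field} = getfield(l, field)
unwrap(x) = x

descend(x, k::Symbol) = getproperty(unwrap(x), k)
descend(x, k::Integer) = unwrap(x)[k]

getlayerpath(l, kp::KeyPath) = foldl(descend, kp.keys; init=l)

getlayerpath(c, KeyPath(:chain, :dense_1))  # returns the inner Dense(2 => 3)
```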

dmetivie commented 5 days ago

Another issue, which I think is related since the keys of the layer and of `ps` are treated differently: `layer_map` fails with `MaxPool` layers. Using the chain from the MNIST tutorial:

```julia
c = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)))
ps, st = Lux.setup(rng, c)  # fresh setup for this chain
_, ps_new, _ = Lux.Experimental.layer_map(zero_dense_params, c, ps, st);
# ERROR: ArgumentError: keys(layer_children) == keys(ps_children) must hold. Got
# keys(layer_children) => Base.OneTo(0)
# keys(ps_children) => ()
```
The same example works when the two `MaxPool` layers are removed.

avik-pal commented 4 days ago

This is mostly by design of `AbstractLuxWrapperLayer`. See https://github.com/LuxDL/Lux.jl/blob/bbf503374b42432324654d4701d284fa5bac74f3/src/contrib/map.jl#L67-L80 for how we do the traversal.

That said, can you tell me about your specific use case? If you are trying to debug something, then https://lux.csail.mit.edu/stable/api/Lux/contrib#Lux.Experimental.@debug_mode should print out the exact path. As for `layer_map`, the layer, `ps`, and `st` are already available directly to the input function.
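Here is a rough usage sketch of that route (the model and names are just illustrative; see the linked docs for the details):

```julia
using Lux, Random

model = Chain(Dense(2 => 4, relu), Dense(4 => 1))
model_debug = Lux.Experimental.@debug_mode model

ps, st = Lux.setup(Random.default_rng(), model_debug)
x = rand(Float32, 2, 3)
y, _ = model_debug(x, ps, st)  # should log each layer's KeyPath location as it runs
```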

dmetivie commented 3 days ago

Thanks, I did look quite a bit into map.jl, but could not understand everything. To me, it looks like a bug that `layer_map` errors with `MaxPool` layers.

My use case: I am trying to implement a Concrete Dropout (CD) layer. It is basically a Dropout with a trainable rate; see here for the PyTorch and TensorFlow implementations. I tried a Julia implementation with Flux and lately Lux, but I struggle with a few things. First, unrelated: I wanted to use package extensions to conditionally load the Flux or Lux version. I did not succeed, so the latest version is Lux only.
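To give an idea of what "Dropout with a trainable rate" means in Lux terms, here is a very rough sketch (NOT the actual ConcreteDropout.jl code; the name `TrainableDropout` is made up, and the paper's concrete relaxation is omitted to keep the Lux plumbing visible):

```julia
using Lux, Random

# The rate is stored in `ps` as a logit so it stays in (0, 1) and is trainable.
struct TrainableDropout <: Lux.AbstractLuxLayer
    init_p::Float32
end

Lux.initialparameters(::Random.AbstractRNG, l::TrainableDropout) =
    (; p_logit=Float32[log(l.init_p / (1 - l.init_p))])
Lux.initialstates(rng::Random.AbstractRNG, ::TrainableDropout) =
    (; rng=Lux.replicate(rng))

function (l::TrainableDropout)(x, ps, st)
    p = 1 / (1 + exp(-ps.p_logit[1]))          # trainable dropout rate
    rng = Lux.replicate(st.rng)
    mask = rand(rng, Float32, size(x)) .>= p   # hard mask (not differentiable in p;
                                               # the concrete relaxation fixes this)
    return x .* mask ./ (1 - p), (; rng)
end
```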

I implemented the CD layer; however, in the original implementation they add a regularization term in the loss that depends on the weights of the layer the dropout is applied to and on the dropout rate.
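Roughly, as I read the reference implementations linked above (my paraphrase, not their exact code; `λw` and `λp` are placeholder constants):

```julia
# A weight penalty scaled by 1/(1 - p) plus the negative entropy of the rate p,
# scaled by the input dimension. `w` is the weight matrix of the layer the CD
# is applied to (out × in for a Lux Dense, so size(w, 2) is the input dim).
function cd_regularizer(w, p; λw=1f-6, λp=1f-5)
    weight_term = λw * sum(abs2, w) / (1 - p)
    rate_term = λp * size(w, 2) * (p * log(p) + (1 - p) * log(1 - p))
    return weight_term + rate_term
end
```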

To automatically get the paths of the relevant layers, i.e. all CD layers and the layer just before each one (where the CD is applied), I wanted to design a `layer_map`-like function to call before training. This path function would give easy access to these layers' coefficients during training, i.e. the weights and the CD rates. Pre v1.0, I used `@layer_map`, and it was working with a lot of hacks (the path was a string, if I remember correctly). Here is the code for pre-v1.0 Lux.

To update it post v1.0, I tried:

```julia
using Lux, Random
using Functors: KeyPath  # Functors is a Lux dependency

function get_key_type!(kp_cd, kp_layer, t_layer, l, ps, st, name, name_prev, t_prev)
    if l isa Dropout
        # example with a plain Dropout so anyone can test without `ConcreteDropout.jl`
        push!(kp_cd, name)          # path of the (Concrete) Dropout layer
        push!(kp_layer, name_prev)  # path of the layer just before it
        push!(t_layer, t_prev)      # type of that previous layer
    end
    return l, ps, st
end

function layer_map_with_previous(l, ps, st)
    kp_cd = KeyPath[]
    kp_layer = KeyPath[]
    t_layer = AbstractLuxLayer[]
    kp_prev = KeyPath(1)
    t_prev = Dense(1 => 1)
    Lux.Functors.fmap_with_path(l, ps, st;
            walk=Lux.Experimental.LayerWalkWithPath()) do kp, layer, ps_, st_
        l__, ps__, st__ = get_key_type!(kp_cd, kp_layer, t_layer, layer, ps_, st_, kp, kp_prev, t_prev)
        kp_prev = kp
        t_prev = layer
        return l__, ps__, st__  # needed for the walk not to error, but unused here
    end
    return kp_cd, kp_layer, t_layer
end

m = Chain(
    Dense(10 => 100),
    Dropout(0.5),
    Dense(100 => 2))
rng = Random.default_rng()
ps, st = Lux.setup(rng, m)
key_CD, key_layer_before, type_of_layer_before = layer_map_with_previous(m, ps, st)
```

Now we can get all the weights to put in the loss with `getkeypath`:

```julia
getkeypath(ps, key_layer_before[1]).weight
```

This works as intended, but with a `MaxPool` layer (maybe others?) it does not; see the reproduction sketch below.
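For instance, reusing the definitions above (the layer shapes are nonsense on purpose; the model is never run, only the walk):

```julia
m2 = Chain(Dense(10 => 100), Dropout(0.5), MaxPool((2, 2)), Dense(100 => 2))
ps2, st2 = Lux.setup(rng, m2)
layer_map_with_previous(m2, ps2, st2)
# ERROR: ArgumentError: keys(layer_children) == keys(ps_children) must hold. ...
```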

There is probably a simpler way to code all of that. Do you have any idea how? BTW, I don't know where this layer should live (it is probably too specific to put directly in Lux.jl).