keras-team / keras

Sharing weights across layers in keras 3 [feature request] #18821

Status: Open. nhuet opened this issue 10 months ago

nhuet commented 10 months ago

It seems that sharing weights after a layer has been created is no longer possible in Keras 3. We are supposed to share layers instead, as explained here.

But I have a use case where I need to share a weight.

In my use case, I transform a model by splitting the activation out of each layer, i.e. a Dense(3, activation="relu") is transformed into a Dense(3) followed by an Activation layer. But I need the new Dense(3) to share its weights with the original layer.
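Schematically, the transformation looks like this (the names here are only for illustration):

from keras.layers import Activation, Dense

# original layer: Dense with a fused activation
fused = Dense(3, activation="relu")

# transformed version: the same Dense without activation, followed by a
# separate Activation layer; the new Dense should reuse the weights of the
# original Dense instead of creating its own
split_dense = Dense(3)
split_activation = Activation("relu")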

For now I have a solution, but it uses a private attribute, since by design this is currently not possible in Keras 3.

Here is an example that works for sharing a kernel (I will actually use something more generic to share any weight, but this is simpler to look at):

from keras.layers import Input, Dense

def share_kernel_and_build(layer1, layer2):
    # Check that layer1 is built and layer2 is not built
    if not layer1.built:
        raise ValueError("The first layer must already be built for sharing its kernel.")
    if layer2.built:
        raise ValueError("The second layer must not be built to get the kernel of another layer")
    # Check that the input really exists (i.e. that layer1 has already been called on a symbolic KerasTensor)
    input = layer1.input  # raises a ValueError if it does not exist

    # store the kernel as a layer2 variable before build (i.e. before layer2's weights get locked)
    layer2.kernel = layer1.kernel
    # build the layer
    layer2(input)
    # overwrite the newly generated kernel
    kernel_to_drop = layer2.kernel
    layer2.kernel = layer1.kernel
    # untrack the kernel that is no longer used (oops: using a private attribute!)
    layer2._tracker.untrack(kernel_to_drop)

layer1 = Dense(3)
input = Input((1,))
output = layer1(input)
layer2 = Dense(3)

share_kernel_and_build(layer1, layer2)

assert layer2.kernel is layer1.kernel
assert len(layer2.weights) == 2

Notes:

fchollet commented 9 months ago

A simpler solution to your problem would be:

  1. Instantiate the new Dense layer, e.g. dense = Dense.from_config(...). (It doesn't have weights at that time)
  2. Set dense.kernel = old_layer.kernel, dense.bias = old_layer.bias, dense.built = True
  3. Just use the layer -- no new weights will be created since the layer is already built
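For example, a rough sketch of these steps (assuming old_layer is an already-built Dense, on a Keras 3 version where kernel is directly settable):

from keras.layers import Dense, Input

inputs = Input((4,))
old_layer = Dense(3)
old_layer(inputs)  # build the original layer

# 1. recreate a layer with the same config; it has no weights at this point
dense = Dense.from_config(old_layer.get_config())

# 2. reuse the existing variables and mark the layer as built
dense.kernel = old_layer.kernel
dense.bias = old_layer.bias
dense.built = True

# 3. just use the layer; no new weights are created since it is already built
outputs = dense(inputs)
assert dense.kernel is old_layer.kernel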
nhuet commented 9 months ago

Nice! But are we sure that the build() method only creates the weights? Perhaps I will miss something else by skipping build()? I would like a solution that works with any layer. By setting self.built = True, I skip build() and thus do not overwrite the weights, but is there anything else that it would be important not to bypass so that call() works? At least, it seems that build() also sets the input_spec attribute, but perhaps losing that is not a big deal (and I can also copy it from the previous layer).
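For example, something along these lines (reusing the names from the sketch above; untested):

# in addition to the weights, copy what build() would otherwise have set
dense.input_spec = old_layer.input_spec
dense.built = True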

nhuet commented 8 months ago

A simpler solution to your problem would be:

  1. Instantiate the new Dense layer, e.g. dense = Dense.from_config(...). (It doesn't have weights at that time)
  2. Set dense.kernel = old_layer.kernel, dense.bias = old_layer.bias, dense.built = True
  3. Just use the layer -- no new weights will be created since the layer is already built

This no longer works as of Keras 3.0.3, since Dense.kernel is now a property without a setter...

fchollet commented 5 months ago

We'll add a setter for the kernel.

nhuet commented 5 months ago

Thx!

fchollet commented 5 months ago

The setter turned out to be problematic. What I would recommend is just direct setting, but using ._kernel instead of .kernel.

Ref: https://github.com/keras-team/keras/pull/19469
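For example, a sketch of that approach (assuming a Keras version that stores the kernel in ._kernel, with old_layer an already-built Dense):

from keras.layers import Dense, Input

inputs = Input((4,))
old_layer = Dense(3)
old_layer(inputs)  # build the original layer

dense = Dense.from_config(old_layer.get_config())
# set the private attribute backing the kernel property, plus the bias
dense._kernel = old_layer.kernel
dense.bias = old_layer.bias
dense.built = True

outputs = dense(inputs)  # no new weights are created
assert dense.kernel is old_layer.kernel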