martinsbruveris / tensorflow-image-models

TensorFlow port of PyTorch Image Models (timm) - image models with pretrained weights.
https://tfimm.readthedocs.io/en/latest/
Apache License 2.0

Adding support for LoRA #85

Open: martinsbruveris opened this issue 1 year ago

martinsbruveris commented 1 year ago

I am using this issue to think through the options for adding support for LoRA (low-rank adaptation).

The first fundamental question is: can we do this without changing the model code itself? That is, assuming we are not allowed to touch convnext.py, how would we proceed?

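Ultimately, our goal is to replace (all? some?) tf.keras.layers.Conv2D layers with new LoraConv2D layers. We could try to do this via monkey patching, i.e., by temporarily replacing the Conv2D class that convnext.py looks up while the model is being constructed.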
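For reference, LoRA keeps the pretrained weight W frozen and trains only a low-rank update, so the effective weight is W + (alpha / r) * A @ B with rank r much smaller than the weight's dimensions. The sketch below shows one common way to apply this to a convolution kernel: flatten the Keras kernel layout (kh, kw, c_in, c_out) to a matrix, add the low-rank product, and reshape back. It uses NumPy instead of TensorFlow so it runs standalone; the helper name `lora_delta_kernel`, the factor shapes, and the hyperparameters are illustrative assumptions, not the design LoraConv2D will necessarily use.

```python
import numpy as np


def lora_delta_kernel(a, b, kernel_shape, alpha, rank):
    """Hypothetical helper: low-rank update for a Conv2D kernel.

    a has shape (kh * kw * c_in, rank) and b has shape (rank, c_out).
    Their product is a rank-`rank` matrix, reshaped back to the kernel
    layout (kh, kw, c_in, c_out) and scaled by alpha / rank.
    """
    return (alpha / rank) * (a @ b).reshape(kernel_shape)


rng = np.random.default_rng(0)
kh, kw, c_in, c_out, rank, alpha = 3, 3, 64, 128, 4, 8.0

# Frozen pretrained kernel in the Keras channels-last layout.
w = rng.normal(size=(kh, kw, c_in, c_out))
# Trainable LoRA factors: `a` small random, `b` zero, so the update starts at 0.
a = rng.normal(scale=0.01, size=(kh * kw * c_in, rank))
b = np.zeros((rank, c_out))

w_eff = w + lora_delta_kernel(a, b, w.shape, alpha, rank)
print(np.allclose(w_eff, w))  # True: at initialisation the layer is unchanged

full = w.size              # parameters in the full kernel
lora = a.size + b.size     # trainable LoRA parameters
print(full, lora)          # 73728 2816
```

Zero-initialising one factor means that swapping Conv2D for a LoRA layer does not change the pretrained model's outputs until training starts, which is what makes a drop-in replacement via monkey patching plausible.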

Note: All code samples are based on the convnext-tag branch.

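```python
import tensorflow as tf
from tfimm.architectures import convnext


class LoraConv2D(tf.keras.layers.Conv2D):
    ...  # LoRA logic to be added


def main():
    # Model class and config for the ConvNeXt-atto variant
    cls, cfg = convnext.convnext_atto()

    # Monkey-patching conv layer. convnext.tf is the same module object as
    # tf, so this replaces Conv2D globally until it is reverted.
    old_conv_layer = convnext.tf.keras.layers.Conv2D
    convnext.tf.keras.layers.Conv2D = LoraConv2D
    try:
        # Layers are created (and built by the dummy forward pass) while the
        # patch is active, so they pick up LoraConv2D.
        model = cls(cfg=cfg)
        model(model.dummy_inputs)
    finally:
        # Reversing changes, even if model construction fails.
        # This would become a context manager of course.
        convnext.tf.keras.layers.Conv2D = old_conv_layer

    # stem[0] is the first convolutional layer in the stem
    print(type(model.stem[0]))


if __name__ == "__main__":
    main()
```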
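The reverting step is meant to become a context manager. Here is a minimal sketch of one, assuming a plain attribute swap is all that is needed; the name `patched_attr` is hypothetical, and simple stand-in classes replace `tf.keras.layers` so it runs without TensorFlow. Because `convnext.tf` is the very same module object as the `tf` imported by the caller, the patch is visible process-wide while it is active, which is why restoring it in a `finally` clause matters.

```python
import contextlib
import types


@contextlib.contextmanager
def patched_attr(obj, name, replacement):
    """Temporarily set obj.<name> to `replacement`, restoring it on exit.

    The `finally` clause restores the original even if the body raises,
    e.g. if model construction fails halfway through.
    """
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield original
    finally:
        setattr(obj, name, original)


# Stand-ins for tf.keras.layers and the two layer classes.
class Conv2D:
    pass


class LoraConv2D(Conv2D):
    pass


layers = types.SimpleNamespace(Conv2D=Conv2D)

with patched_attr(layers, "Conv2D", LoraConv2D):
    built = layers.Conv2D()  # what model construction would see

print(type(built).__name__)     # LoraConv2D
print(layers.Conv2D is Conv2D)  # True: the patch has been reverted

# The original is restored even when the body raises.
try:
    with patched_attr(layers, "Conv2D", LoraConv2D):
        raise RuntimeError("model construction failed")
except RuntimeError:
    pass
print(layers.Conv2D is Conv2D)  # True
```

In the script above this would read `with patched_attr(convnext.tf.keras.layers, "Conv2D", LoraConv2D):` around the model construction and dummy forward pass. The standard library's `unittest.mock.patch.object(target, "Conv2D", LoraConv2D)` can also be used as a context manager and does the same job without a custom helper.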

This works, but it has some drawbacks:

The patch swaps every Conv2D layer created while it is active. We would like layer-wise control, i.e., to swap only some layers but not others. We could do that by specifying the names of the layers to be swapped.
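One way to get name-based control while still relying on monkey patching is to patch in a factory instead of a class: on each call it looks at the `name` keyword argument and builds either the LoRA layer or the original one. The sketch below assumes the model passes `name=` as a keyword when constructing each conv layer; the helper `selective_layer`, the glob patterns, and the layer names are hypothetical, and stand-in classes replace the Keras ones so it runs without TensorFlow.

```python
import fnmatch


def selective_layer(original_cls, replacement_cls, patterns):
    """Hypothetical helper: build `replacement_cls` only for matching names.

    The returned callable is meant to be patched in place of Conv2D. Layers
    whose `name` matches one of the glob `patterns` get the replacement;
    all others (including unnamed ones) get the original class.
    """
    def factory(*args, **kwargs):
        name = kwargs.get("name") or ""
        if any(fnmatch.fnmatchcase(name, p) for p in patterns):
            return replacement_cls(*args, **kwargs)
        return original_cls(*args, **kwargs)

    return factory


# Stand-in layer classes.
class Conv2D:
    def __init__(self, filters, name=None):
        self.filters = filters
        self.name = name


class LoraConv2D(Conv2D):
    pass


make_conv = selective_layer(Conv2D, LoraConv2D, patterns=["stage*/conv*"])
swapped = make_conv(64, name="stage1/conv1")
kept = make_conv(64, name="stem/conv")
unnamed = make_conv(64)
print(type(swapped).__name__, type(kept).__name__, type(unnamed).__name__)
# LoraConv2D Conv2D Conv2D
```

Two caveats with this design: the factory is a function, not a class, so any `isinstance` check or subclass definition against `tf.keras.layers.Conv2D` that runs while the patch is active would break; and a Keras layer's `name` is whatever the constructing code passes, not necessarily a full path through the model, so the patterns have to match the names convnext.py actually uses.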