martinsbruveris / tensorflow-image-models

TensorFlow port of PyTorch Image Models (timm) - image models with pretrained weights.
https://tfimm.readthedocs.io/en/latest/
Apache License 2.0

Lora dense layer #86

Closed: martinsbruveris closed this PR 1 year ago

martinsbruveris commented 1 year ago

I am creating the PR to make it easier to discuss the code.

kevin-keraudren commented 1 year ago

@martinsbruveris I can now run the following code:

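```python
import tfimm
from tfimm.architectures import lora
from tfimm.models.factory import transfer_weights

# ConvNeXt-Tiny with pretrained weights and rank-4 LoRA adapters
model = lora.create_model("convnext_tiny", pretrained=True, lora_rank=4)
# Freeze everything except the LoRA weights; train_bias="none" trains no biases
lora.mark_only_lora_as_trainable(model, train_bias="none")

# Plain ConvNeXt-Tiny (no LoRA layers) that will receive the weights
convnext_model = tfimm.create_model("convnext_tiny")
# Fold the low-rank LoRA updates into the base dense kernels
model.merge_lora_weights()
# Copy the merged weights into the plain model
transfer_weights(model, convnext_model)
```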
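To illustrate why merging makes the final `transfer_weights` call possible, here is a minimal NumPy sketch of what a LoRA dense layer computes. It is not the tfimm implementation: the class `LoraDenseSketch`, its `alpha` parameter and the `alpha / rank` scaling are hypothetical, borrowed from the original LoRA formulation, and the PR's actual layer may differ. The kernel uses the Keras `Dense` layout `(input_dim, units)`.

```python
import numpy as np


class LoraDenseSketch:
    """Hypothetical NumPy stand-in for a LoRA dense layer (not tfimm's class)."""

    def __init__(self, kernel, bias, rank=4, alpha=8.0, seed=0):
        rng = np.random.default_rng(seed)
        d_in, d_out = kernel.shape
        self.kernel = kernel  # frozen pretrained weight, shape (d_in, d_out)
        self.bias = bias      # frozen bias, shape (d_out,)
        # Low-rank factors: A is random, B starts at zero, so the layer
        # initially reproduces the pretrained output exactly.
        self.lora_a = rng.normal(scale=0.01, size=(d_in, rank))
        self.lora_b = np.zeros((rank, d_out))
        self.scaling = alpha / rank  # assumed alpha / rank convention
        self.merged = False

    def trainable_weights(self):
        # Analogue of "only LoRA weights trainable, no bias training".
        return [self.lora_a, self.lora_b]

    def __call__(self, x):
        y = x @ self.kernel + self.bias
        if not self.merged:
            # Separate low-rank branch: scaling * x @ A @ B
            y = y + (x @ self.lora_a @ self.lora_b) * self.scaling
        return y

    def merge(self):
        # Fold the update into the kernel: W' = W + scaling * A @ B
        if not self.merged:
            self.kernel = self.kernel + self.scaling * (self.lora_a @ self.lora_b)
            self.merged = True


rng = np.random.default_rng(1)
base_kernel = rng.normal(size=(16, 8))
layer = LoraDenseSketch(base_kernel, np.zeros(8), rank=4)
x = rng.normal(size=(2, 16))

untrained = layer(x)                   # B is zero, so this equals the base layer
layer.lora_b = rng.normal(size=(4, 8))  # stand-in for a trained B
before = layer(x)
layer.merge()
after = layer(x)
# After merging, a plain dense layer with the merged kernel gives the same output
plain = x @ layer.kernel + layer.bias
print(np.allclose(before, after), np.allclose(after, plain))  # True True
```

The merged kernel has the same shape as the original one, which is why, in this sketch, the weights of a merged LoRA model can be copied one-to-one into a model without LoRA layers.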

Let me know if anything is still outstanding before this PR can be merged.