THUDM / SwissArmyTransformer

SwissArmyTransformer is a flexible and powerful library to develop your own Transformer variants.
https://THUDM.github.io/SwissArmyTransformer
Apache License 2.0

How to embed video encoder module from pytorch? #148

Open zyhzyh88 opened 9 months ago

1049451037 commented 9 months ago

More detail is needed. What exactly do you mean by "embed video encoder module from pytorch"?

zyhzyh88 commented 9 months ago

Sorry! I already have a video encoder written in PyTorch. How can I fully embed this module into the SAT framework?

1049451037 commented 9 months ago

Just replace the model in the fine-tuning script with your PyTorch module (SAT models are ordinary PyTorch modules):

https://github.com/THUDM/SwissArmyTransformer/blob/e6abe2a9e7bc273d4dc59560a00fa0543d07919a/examples/vit/finetune_vit_cifar10.py#L92
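As a rough illustration, any nn.Module can fill that slot. The sketch below assumes a hypothetical VideoEncoder (a small 3D-convolution network with made-up sizes) standing in for your own encoder. It does not reproduce how the linked script builds its model; it only shows that the replacement is an ordinary nn.Module that takes a (batch, channels, frames, height, width) tensor.

```python
import torch
from torch import nn


class VideoEncoder(nn.Module):
    """Hypothetical stand-in for an existing PyTorch video encoder."""

    def __init__(self, in_channels: int = 3, hidden: int = 64, num_classes: int = 10):
        super().__init__()
        # 3D convolution mixes information across frames, height and width.
        self.backbone = nn.Sequential(
            nn.Conv3d(in_channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool3d(1),  # -> (B, hidden, 1, 1, 1)
        )
        # Head named "mlp" (a hypothetical choice) so it matches the
        # enable = ['mlp'] keyword filter used in this thread.
        self.mlp = nn.Linear(hidden, num_classes)

    def forward(self, video: torch.Tensor) -> torch.Tensor:
        # video: (batch, channels, frames, height, width)
        feats = self.backbone(video).flatten(1)  # (B, hidden)
        return self.mlp(feats)


model = VideoEncoder()
logits = model(torch.randn(2, 3, 8, 32, 32))
print(logits.shape)  # torch.Size([2, 10])
```

Because the encoder is a plain nn.Module, it can be passed wherever the script expects a model object, as long as its forward signature matches what the training step feeds it.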

One more thing: you may also need to add a disable_untrainable_params function to your model to control which parameters are trained:

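import types

from sat.helpers import print_rank0


def disable_untrainable_params(self):
    # Freeze every parameter whose name does not contain a keyword in `enable`.
    total_trainable = 0
    enable = ['mlp']  # substrings of parameter names that stay trainable
    for n, p in self.named_parameters():
        if any(e.lower() in n.lower() for e in enable):
            total_trainable += p.numel()
            print_rank0(n)
        else:
            p.requires_grad_(False)
    print_rank0("***** Total trainable parameters: " + str(total_trainable) + " *****")

# Bind the function to the instance so `self` is passed when it is called as
# model.disable_untrainable_params(); plain assignment stores an unbound function.
model.disable_untrainable_params = types.MethodType(disable_untrainable_params, model)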
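If you control the model class, the same filter can instead be defined as a regular method, which avoids binding it by hand. A minimal sketch, assuming a hypothetical ToyVideoModel with an "encoder" and an "mlp" head; it uses plain print instead of print_rank0 so it runs without SAT installed, and the optimizer line is ordinary PyTorch, not SAT's own optimizer setup.

```python
import torch
from torch import nn


class ToyVideoModel(nn.Module):
    """Hypothetical toy model: an encoder to freeze plus an 'mlp' head to train."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv3d(3, 8, kernel_size=3, padding=1)
        self.mlp = nn.Linear(8, 4)

    def disable_untrainable_params(self, enable=("mlp",)):
        # Freeze every parameter whose name does not contain one of the keywords.
        total_trainable = 0
        for name, param in self.named_parameters():
            if any(key.lower() in name.lower() for key in enable):
                total_trainable += param.numel()
            else:
                param.requires_grad_(False)
        print(f"Total trainable parameters: {total_trainable}")
        return total_trainable


model = ToyVideoModel()
trainable = model.disable_untrainable_params()  # prints 36 (8*4 weights + 4 biases)

# Hand only the still-trainable parameters to a plain PyTorch optimizer.
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-4
)
```

Matching on name substrings keeps the filter independent of the encoder's internals: renaming or adding layers only requires updating the keyword tuple.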