thuml / Transfer-Learning-Library

Transfer Learning Library for Domain Adaptation, Task Adaptation, and Domain Generalization
http://transfer.thuml.ai
MIT License

Customizing my efficient net for domain adaptation #139

Closed: sparshgarg23 closed this issue 2 years ago

sparshgarg23 commented 2 years ago

I am interested in applying DANN to EfficientNet. My current EfficientNet model, based on the timm library, is shown below:

import timm
import torch.nn as nn


class AMT_Model_1(nn.Module):

    def __init__(self):
        super(AMT_Model_1, self).__init__()
        # backbone: timm EfficientNet with the classification head removed (num_classes=0)
        self.model = timm.create_model('tf_efficientnet_b3_ns', pretrained=True, num_classes=0)
        for param in self.model.parameters():
            param.requires_grad = False
        self.bottle_neck = nn.Sequential(
            nn.Linear(in_features=self.model.num_features, out_features=625),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(in_features=625, out_features=256),
            nn.ReLU(),
            nn.Dropout(p=0.2))

        self.classifier = nn.Linear(in_features=256, out_features=3)

    def forward(self, x):
        x = self.model(x)
        x = self.bottle_neck(x)
        out = self.classifier(x)
        return out


def build_model():
    model = AMT_Model_1()
    return model

As per the docs, customized backbones must follow the pattern below:

import torch.nn as nn

class FeatureExtractor(nn.Module):

    def __init__(self):
        pass

    def forward(self, x):
        pass

    @property
    def out_features(self):
        pass

your_backbone = FeatureExtractor()

What changes should I make to my model to ensure that it works correctly with your library?

thucbx99 commented 2 years ago

You can refer to the Classifier class to implement the corresponding methods, such as get_parameters. Alternatively, you can directly modify the code in dann.py to fit your implementation of EfficientNet.
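
For reference, here is a minimal sketch of how the timm EfficientNet could be wrapped so that it matches the FeatureExtractor pattern quoted above. The class name EfficientNetBackbone is made up for illustration, and the sketch assumes the library only needs the backbone to return pooled features and to expose their dimension via out_features:

import timm
import torch.nn as nn


class EfficientNetBackbone(nn.Module):
    # Hypothetical wrapper (not part of the library): exposes a timm
    # EfficientNet through the FeatureExtractor interface shown above.

    def __init__(self, name='tf_efficientnet_b3_ns', pretrained=True):
        super().__init__()
        # num_classes=0 makes timm return pooled features instead of logits
        self.model = timm.create_model(name, pretrained=pretrained, num_classes=0)

    def forward(self, x):
        # shape: (batch_size, self.out_features)
        return self.model(x)

    @property
    def out_features(self):
        # timm exposes the feature dimension as num_features (1536 for b3)
        return self.model.num_features


your_backbone = EfficientNetBackbone()

With a wrapper like this, the bottleneck and the 3-class head from AMT_Model_1 would move out of the backbone and into the Classifier-style model that dann.py builds (which is also where a get_parameters method, mirroring the one in the library's Classifier, would live), rather than being hard-coded inside the feature extractor.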