SSARCandy / DeepCORAL

🧠 A PyTorch implementation of 'Deep CORAL: Correlation Alignment for Deep Domain Adaptation' (ECCV 2016)
https://ssarcandy.tw/2017/10/31/deep-coral/

Pretrained model #8

Closed. deep0learning closed this issue 6 years ago.

deep0learning commented 6 years ago

Your code uses a pretrained model. Suppose I don't want to use a pretrained model: what should I do? I am writing data loader code. Hopefully I will get accuracy as good as in the paper, and then I will send a pull request.

Please help me so that I can use your code without a pretrained model.

Thanks in advance.

redhat12345 commented 6 years ago

Here is the code that loads the pretrained model:

import torch.utils.model_zoo as model_zoo

# load AlexNet pre-trained (ImageNet) weights into `model`
def load_pretrained(model):
    url = 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth'
    pretrained_dict = model_zoo.load_url(url)
    model_dict = model.state_dict()

    # keep only the keys that also exist in this model;
    # uncomment the two del lines to skip the last fc layer's bias and weight
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # del pretrained_dict['classifier.6.bias']
    # del pretrained_dict['classifier.6.weight']

    # overwrite the matching entries and load the merged state dict
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
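The key filter above only checks names, so a layer whose shape differs (for example a classifier resized for a different number of classes) would still make load_state_dict fail; that is presumably what the two commented-out del lines are for. A minimal sketch, assuming you would rather drop any tensor whose shape does not match as well; the helper name load_matching is hypothetical and not part of the repository, and the toy Sequential modules only stand in for AlexNet:

```python
import torch.nn as nn


def load_matching(model, pretrained_dict):
    """Copy only entries whose name AND shape match `model` (hypothetical helper)."""
    model_dict = model.state_dict()
    matched = {k: v for k, v in pretrained_dict.items()
               if k in model_dict and v.size() == model_dict[k].size()}
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    return sorted(matched)  # names of the tensors that were actually copied


# Toy stand-ins: identical first layer, different classifier size.
src = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1000))
dst = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 10))
copied = load_matching(dst, src.state_dict())
print(copied)  # ['0.bias', '0.weight']
```

The mismatched classifier keeps its random initialization instead of raising a size-mismatch error.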
deep0learning commented 6 years ago

@redhat12345 Thanks for your reply. Can anyone please help me train the model without the pretrained model?

SSARCandy commented 6 years ago

You can remove the load_pretrained() call in __main__ and the CORAL loss in train().

deep0learning commented 6 years ago

@SSARCandy Thank you so much for your reply. I want to keep the CORAL loss but not use the pretrained model. How can I do that? Thanks in advance.

SSARCandy commented 6 years ago

Oh. Just remove the load_pretrained() call in __main__; then it will not load the ImageNet pretrained model.
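Removing load_pretrained() only changes the initialization; the CORAL term in train() still applies to the randomly initialized network. For reference, the paper defines the CORAL loss as the squared Frobenius distance between the source and target feature covariances, scaled by 1/(4d^2), where d is the feature dimension. The sketch below follows that definition but is not copied from this repository, so the function name coral_loss is an assumption:

```python
import torch


def coral_loss(source, target):
    """CORAL loss ||C_S - C_T||_F^2 / (4 d^2) for (batch, d) feature matrices."""
    d = source.size(1)

    def covariance(x):
        # unbiased covariance of the columns (features) of x
        x = x - x.mean(dim=0, keepdim=True)
        return x.t() @ x / (x.size(0) - 1)

    diff = covariance(source) - covariance(target)
    return (diff * diff).sum() / (4 * d * d)


torch.manual_seed(0)
src_feat = torch.randn(32, 10)
tgt_feat = torch.randn(32, 10) * 2 + 1
print(coral_loss(src_feat, src_feat).item())  # 0.0: identical statistics
print(coral_loss(src_feat, tgt_feat).item())  # positive: covariances differ
```

In a training step the term is typically added to the source classification loss, e.g. loss = F.cross_entropy(source_out, source_labels) + lambda_coral * coral_loss(source_out, target_out), where lambda_coral is a weight you choose (a hypothetical name here). Since only covariances are compared, a pure mean shift between domains contributes nothing to this term.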

deep0learning commented 6 years ago

I removed load_pretrained() in __main__, but got this error:

Traceback (most recent call last):
  File "main.py", line 160, in <module>
    load_pretrained(model.sharedNet)
NameError: name 'load_pretrained' is not defined

SSARCandy commented 6 years ago

All you have to do is remove lines 143-146.

deep0learning commented 6 years ago

@SSARCandy

Thank you so much. I want to use LeNet for SVHN to MNIST adaptation, but I got an error. I have created a LeNet class. Can you please tell me how to do that? Thanks in advance.

transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])

# SVHN to MNIST
train_dataset = datasets.SVHN('SVHN', download=True, transform=transform, split='train')
valid_dataset = datasets.MNIST('MNIST', download=True, transform=transform, train=True)

source_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE[0], shuffle=True)
target_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=BATCH_SIZE[1], shuffle=True)

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

I have also changed the DeepCORAL class:

class DeepCORAL(nn.Module):
    def __init__(self, num_classes=1000):
        super(DeepCORAL, self).__init__()
        self.sharedNet = LeNet()
        self.source_fc = nn.Linear(84, num_classes)
        self.target_fc = nn.Linear(84, num_classes)

        # initialize according to CORAL paper experiment
        self.source_fc.weight.data.normal_(0, 0.005)
        self.target_fc.weight.data.normal_(0, 0.005)
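As posted, LeNet.forward returns the 10-way fc3 output, while source_fc and target_fc expect 84 inputs, so the two classes only fit together if the shared network stops at fc2 (assuming DeepCORAL.forward feeds the sharedNet output into source_fc, as the 84-input layers suggest). A minimal sketch of that arrangement; LeNetFeatures and DeepCORALLeNet are hypothetical names, and sharing one classifier across both domains is an assumption of this sketch, not a statement about the repository's code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNetFeatures(nn.Module):
    """LeNet trunk that returns the 84-d fc2 activations (fc3 removed)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # assumes 32x32 inputs
        self.fc2 = nn.Linear(120, 84)

    def forward(self, x):
        out = F.max_pool2d(F.relu(self.conv1(x)), 2)
        out = F.max_pool2d(F.relu(self.conv2(out)), 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        return F.relu(self.fc2(out))


class DeepCORALLeNet(nn.Module):
    """Shared LeNet trunk plus one linear classifier applied to both domains."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.sharedNet = LeNetFeatures()
        self.source_fc = nn.Linear(84, num_classes)
        self.source_fc.weight.data.normal_(0, 0.005)

    def forward(self, source, target):
        # same weights for both domains, so the outputs can be compared by CORAL
        source_out = self.source_fc(self.sharedNet(source))
        target_out = self.source_fc(self.sharedNet(target))
        return source_out, target_out


model = DeepCORALLeNet(num_classes=10)
src_out, tgt_out = model(torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32))
print(src_out.shape, tgt_out.shape)  # torch.Size([4, 10]) torch.Size([4, 10])
```

SVHN and MNIST both have 10 digit classes, hence num_classes=10 here rather than the 1000 default above.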
redhat12345 commented 6 years ago

@deep0learning

Change the input size. It should be 32x32:

transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))])
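The 32x32 requirement follows from the layer sizes: each 5x5 convolution (stride 1, no padding) shrinks the side by 4, each 2x2 max-pool halves it, and fc1 expects 16 channels of 5x5. A small pure-Python helper (hypothetical, for illustration only) reproduces the arithmetic:

```python
def lenet_flat_size(side, channels=16):
    """Flattened size after conv5 -> pool2 -> conv5 -> pool2 on a side x side input."""
    for _ in range(2):
        side = side - 5 + 1  # 5x5 conv, stride 1, no padding
        side = side // 2     # 2x2 max-pool, stride 2 (floor)
    return channels * side * side


print(lenet_flat_size(32))  # 400 == 16*5*5, matches fc1
print(lenet_flat_size(28))  # 256 == 16*4*4, size mismatch with fc1
```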

deep0learning commented 6 years ago

@redhat12345

Thank you so much. It works now, but the accuracy is poor.