thuml / Transfer-Learning-Library

Transfer Learning Library for Domain Adaptation, Task Adaptation, and Domain Generalization
http://transfer.thuml.ai
MIT License

TAT low acc #136

Closed: anxingzzz closed this issue 2 years ago

anxingzzz commented 2 years ago

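Hello, I have a question about transferable adversarial training (TAT). When I run it on my dataset, the training accuracy fluctuates heavily and gets close to 1, but the test accuracy stays low, around 0.4. I used ResNet-50 to extract features for both domains, saved them along with the corresponding labels to txt files, and fed them into TAT through a dataloader, using the VisDA hyperparameters. Which details should I pay attention to in order to improve accuracy? Thanks for any guidance. Here is my feature extraction code; I'm not sure whether it is correct: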

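import numpy as np
import torch
import torchvision
from torchvision import datasets

# `root` and `transform2` are defined elsewhere in the original script (not shown).
data = datasets.ImageFolder(root=root, transform=transform2)
data_loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=True, drop_last=False)

DEVICE = torch.device('cuda:0')

model = torchvision.models.resnet50(pretrained=True)
n_features = model.fc.in_features  # 2048 for ResNet-50
# Replace the classifier with a 2048 -> 2048 identity map so the model returns the pooled features.
model.fc = torch.nn.Linear(n_features, 2048)
torch.nn.init.eye_(model.fc.weight)
torch.nn.init.zeros_(model.fc.bias)  # without this, the random default bias shifts every feature
model = model.to(DEVICE)
model.eval()  # use BatchNorm running statistics; torch.no_grad() alone leaves BN in training mode

for param in model.parameters():
    param.requires_grad = False

data_list = []
label_list = []
with torch.no_grad():
    for inputs, labels in data_loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        feas = model(inputs)
        img_path = " " + str(labels.item())
        # labels = labels.view(labels.size(0), 1).float()
        # y = torch.cat((feas, labels), dim=1)
        y = feas.cpu().numpy()
        data_list.append(y)
        # Raw string: in a normal string literal, "\t" in "\t_train.txt" becomes a tab character.
        with open(r'D:\TAT\dataest\quexian\t_train.txt', "a") as f:
            f.write(str(img_path) + '\n')

a = np.concatenate(data_list, axis=0)
print(a.shape)
np.save('t_train.npy', a)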
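Because the features end up in t_train.npy while the labels go to a separate text file (one " <label>" line per image, opened in append mode), the two must stay row-aligned before they are fed to a training dataloader. Below is a minimal sketch, assuming exactly that file layout; the function name load_features_and_labels, the file names, and the synthetic data in the usage part are hypothetical. The returned arrays could then be wrapped with torch.from_numpy and torch.utils.data.TensorDataset.

```python
import os
import tempfile

import numpy as np


def load_features_and_labels(npy_path, label_path):
    """Hypothetical helper: load the np.save feature matrix and the " <label>"
    text file, and check that they have the same number of rows."""
    features = np.load(npy_path)  # shape (N, 2048) for ResNet-50 pooled features
    with open(label_path) as f:
        # Each line looks like " 3\n"; int() ignores the surrounding whitespace.
        labels = np.array([int(line) for line in f if line.strip()], dtype=np.int64)
    if features.shape[0] != labels.shape[0]:
        # np.save overwrites the .npy, but the label file is appended to,
        # so re-running the extraction script leaves extra label lines behind.
        raise ValueError(
            f"{features.shape[0]} feature rows but {labels.shape[0]} labels"
        )
    return features.astype(np.float32), labels


# Usage with small synthetic files (hypothetical data: 4 samples of 2048-d features).
tmp = tempfile.mkdtemp()
npy_path = os.path.join(tmp, "t_train.npy")
label_path = os.path.join(tmp, "t_train.txt")
np.save(npy_path, np.random.rand(4, 2048).astype(np.float32))
with open(label_path, "w") as f:
    f.write("".join(" " + str(y) + "\n" for y in [0, 1, 1, 0]))

X, y = load_features_and_labels(npy_path, label_path)
print(X.shape, y.tolist())  # (4, 2048) [0, 1, 1, 0]
```

Deleting the label file before each extraction run (or opening it once in "w" mode outside the loop) avoids the mismatch this check guards against.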

thucbx99 commented 2 years ago

TransLearn does not implement this algorithm at the moment; you can refer to the code here: https://github.com/thuml/Transferable-Adversarial-Training