JinyuCai95 / EDESC-pytorch

MIT License

About The Experiment on CIFAR10 #3

Status: open. mengxianghan123 opened this issue 1 year ago.

mengxianghan123 commented 1 year ago

Thanks for your GREAT work!! The released code really helps a lot!! But when I tried to replicate the experimental results on CIFAR10, I failed and only got 0.35 accuracy. This might be because of an inappropriate hyper-parameter setting, or a misunderstanding on my part of other experimental details for CIFAR10. I've tried different beta values (0.1, 1, 5, 10) and d values (5, 10). For the feature extraction on CIFAR10, here's my implementation:

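import numpy as np
import torch
import torchvision
from torchvision import transforms
from tqdm import tqdm

def CIFAR10_features():
    cifar10 = torchvision.datasets.CIFAR10(root='/home/mxh/datasets/CIFAR10', train=True, download=False)
    # ImageNet-pretrained ResNet-50 with the classifier replaced by Identity -> 2048-d features
    model = torchvision.models.resnet50(weights='DEFAULT')
    model.fc = torch.nn.Identity()
    model.eval()  # inference mode: BatchNorm uses its running statistics
    to_tensor = transforms.ToTensor()
    img_list = []
    label_list = []
    with torch.no_grad():
        # Train split first
        for idx in tqdm(range(len(cifar10))):
            img, label = cifar10[idx]
            img = to_tensor(img).unsqueeze(0)
            img_list.append(model(img))  # (1, 2048) feature vector
            label_list.append(label)
        # Then the test split
        cifar10 = torchvision.datasets.CIFAR10(root='/home/mxh/datasets/CIFAR10', train=False, download=False)
        for idx in tqdm(range(len(cifar10))):
            img, label = cifar10[idx]
            img = to_tensor(img).unsqueeze(0)
            img_list.append(model(img))
            label_list.append(label)
    img_list = torch.cat(img_list, dim=0).numpy()
    label_list = np.array(label_list)
    data = {'data': img_list, 'label': label_list}
    np.save("/home/mxh/codes/EDESC/data/CIFAR10/cifar.npy", data)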

Could you please give me some advice on the replication of CIFAR10? It would be extremely helpful!! Thanks a lot!!
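The accuracy figures in this thread (0.35, 0.457, and the 0.627 from the paper) are presumably the usual unsupervised clustering accuracy (ACC), where predicted cluster ids are first matched one-to-one to ground-truth labels with the Hungarian algorithm and then the hits are counted. A minimal sketch of that metric, assuming SciPy is available; `cluster_acc` is a hypothetical name, and the evaluation code in EDESC-pytorch may differ in detail:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def cluster_acc(y_true, y_pred):
    """Clustering accuracy under the best one-to-one cluster -> label mapping."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    n = int(max(y_pred.max(), y_true.max())) + 1
    # w[i, j] counts samples assigned to cluster i whose true label is j
    w = np.zeros((n, n), dtype=np.int64)
    for p, t in zip(y_pred, y_true):
        w[p, t] += 1
    # linear_sum_assignment minimizes cost, so negate the counts to maximize matches
    rows, cols = linear_sum_assignment(-w)
    return w[rows, cols].sum() / y_pred.size


if __name__ == '__main__':
    y_true = [0, 0, 1, 1, 2, 2]
    y_pred = [1, 1, 0, 0, 2, 0]  # cluster ids are an arbitrary relabeling
    print(cluster_acc(y_true, y_pred))  # 5 of 6 samples matched
```

Because the matching is one-to-one, ACC does not depend on how the clusters happen to be numbered, which is why the permuted `y_pred` above still scores 5/6.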

mengxianghan123 commented 1 year ago

After adding a pre-processing step (normalizing inputs with the ImageNet mean/std), the ACC now reaches 0.457. But there is still a gap to the 0.627 reported in the paper. Could you please share more details? It would be very helpful!

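import numpy as np
import torch
import torchvision
from torchvision import transforms

# ImageNet per-channel statistics expected by the pretrained ResNet-50
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
def CIFAR10_features():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    cifar10 = torchvision.datasets.CIFAR10(root='/home/mxh/datasets/CIFAR10', train=True, download=False, transform=transform)
    # shuffle=False gives a deterministic order; features and labels stay aligned either way
    trainloader = torch.utils.data.DataLoader(cifar10, batch_size=2048, shuffle=False, num_workers=2)
    model = torchvision.models.resnet50(weights='DEFAULT')
    model.fc = torch.nn.Identity()
    model = model.to(device).eval()  # inference mode: BatchNorm uses running statistics
    img_list = []
    label_list = []
    with torch.no_grad():
        for img, label in trainloader:
            img_list.append(model(img.to(device)).cpu())  # (batch, 2048) features
            label_list.append(label)
        cifar10 = torchvision.datasets.CIFAR10(root='/home/mxh/datasets/CIFAR10', train=False, download=False, transform=transform)
        testloader = torch.utils.data.DataLoader(cifar10, batch_size=2048, shuffle=False, num_workers=2)
        for img, label in testloader:
            img_list.append(model(img.to(device)).cpu())
            label_list.append(label)
    img_list = torch.cat(img_list, dim=0).numpy()
    label_list = torch.cat(label_list, dim=0).numpy()
    data = {'data': img_list, 'label': label_list}
    np.save("/home/mxh/codes/EDESC-pytorch-master/data/CIFAR10/cifar.npy", data)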
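The pre-processing step in this second version is `ToTensor` (uint8 HWC in [0, 255] to float CHW in [0, 1]) followed by `Normalize` with the ImageNet per-channel mean/std, i.e. `(x / 255 - mean) / std` for each channel. A minimal NumPy sketch of the same arithmetic, useful for sanity-checking what the ResNet-50 actually receives; `to_normalized_chw` is a hypothetical helper, not part of torchvision or EDESC-pytorch:

```python
import numpy as np

# Same ImageNet statistics as `mean` / `std` in the script above
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(img_hwc_uint8):
    # ToTensor step: uint8 HWC in [0, 255] -> float32 CHW in [0, 1]
    x = img_hwc_uint8.astype(np.float32).transpose(2, 0, 1) / 255.0
    # Normalize step: per-channel (x - mean) / std, broadcast over H and W
    return (x - MEAN[:, None, None]) / STD[:, None, None]


if __name__ == '__main__':
    img = np.zeros((32, 32, 3), dtype=np.uint8)  # an all-black CIFAR-sized image
    out = to_normalized_chw(img)
    print(out.shape)     # (3, 32, 32)
    print(out[:, 0, 0])  # roughly [-2.118, -2.036, -1.804]
```

Note that both scripts keep the CIFAR images at 32×32, well below the 224×224 crops the torchvision ImageNet weights are evaluated at; this thread does not say which input resolution the paper used.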