MadryLab / robustness

A library for experimenting with, training and evaluating neural networks, with a focus on adversarial robustness.
MIT License

Error encountered in model_utils.make_and_restore_model() function when input is a custom dataset #80

Closed. Icxa closed this issue 3 years ago.

Icxa commented 3 years ago

I have a custom dataset, and I have created a subclass of `DataSet` as follows:

```python
from robustness.datasets import DataSet


class My_dataset(DataSet):

    def __init__(self, ds_name, data_path, std, mean, num_classes, transform_train, transform_test, custom_class):
        self.ds_name = ds_name
        self.data_path = data_path
        self.std = std
        self.mean = mean
        self.num_classes = num_classes
        self.transform_train = transform_train
        self.transform_test = transform_test
        self.custom_class = custom_class
        super().__init__(self.ds_name,
                         self.data_path,
                         std=self.std,
                         mean=self.mean,
                         num_classes=self.num_classes,
                         transform_train=self.transform_train,
                         transform_test=self.transform_test,
                         custom_class=self.custom_class)

    def get_model(arch, pretrained):
        self.arch = arch
        self.pretrained = pretrained
        from robustness import imagenet_models  # or cifar_models
        assert not self.pretrained, "pretrained only available for ImageNet"
        return imagenet_models.__dict__[self.arch](num_classes=self.num_classes)
```

I create an instance of `My_dataset` (`ant_dataset`) and obtain `train_loader` and `val_loader` by calling its `make_loaders()` method.

But when I run the following code, I get an error:

```python
model, _ = model_utils.make_and_restore_model(arch='resnet18', dataset=ant_dataset)
```

The error message was attached as an image.

I have tried several ways of passing the input and implementing the `get_model()` function, but I can't figure out what is going wrong. Can anyone please help?
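Because the error itself was only posted as a screenshot, here is a hedged, dependency-free sketch of how this kind of call can fail. It assumes that `make_and_restore_model` asks the dataset for its network by calling `dataset.get_model(arch, pretrained)` with positional arguments. Under that assumption, a `get_model` declared as in the code above fails before its body ever runs. `BaseDataSet`, `MyDataset`, `make_model_sketch`, and `reproduce` are hypothetical stand-ins, not the robustness API.

```python
# Hypothetical stand-ins for robustness's DataSet class and
# make_and_restore_model; this is NOT the real library code.


class BaseDataSet:
    """Minimal stand-in for a dataset base class exposing a get_model hook."""

    def __init__(self, ds_name, num_classes):
        self.ds_name = ds_name
        self.num_classes = num_classes

    def get_model(self, arch, pretrained):
        raise NotImplementedError


class MyDataset(BaseDataSet):
    # Same signature as in the question: `self` is missing.
    def get_model(arch, pretrained):
        return arch  # never reached: the call fails before the body runs


def make_model_sketch(arch, dataset, pretrained=False):
    # Assumption: the factory asks the dataset to build the network,
    # passing arch and pretrained positionally.
    return dataset.get_model(arch, pretrained)


def reproduce():
    """Return the TypeError message raised by the faulty get_model, if any."""
    ds = MyDataset("ants", num_classes=2)
    try:
        make_model_sketch("resnet18", ds)
    except TypeError as err:
        return str(err)
    return None


if __name__ == "__main__":
    # Prints something like:
    # "MyDataset.get_model() takes 2 positional arguments but 3 were given"
    print(reproduce())
```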

andrewilyas commented 3 years ago

Hi @Icxa! Looks like you're just forgetting `self` in the argument list of `get_model`, in the declaration:

def get_model(arch, pretrained):

should be

def get_model(self, arch, pretrained):

Hopefully this fixes it!
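Why the extra parameter matters: when a method is called through an instance, Python passes that instance as the first positional argument. This means `dataset.get_model(arch, pretrained)` is equivalent to `type(dataset).get_model(dataset, arch, pretrained)`. Declared without `self`, `get_model` accepts two arguments but receives three, which raises a `TypeError`. The minimal pure-Python sketch below uses the corrected signature. `BaseDataSet` and the dictionary return value are hypothetical placeholders for the real `DataSet` class and for the network that would actually be built from `imagenet_models`.

```python
class BaseDataSet:
    """Hypothetical stand-in for a dataset base class with a get_model hook."""

    def __init__(self, ds_name, num_classes):
        self.ds_name = ds_name
        self.num_classes = num_classes

    def get_model(self, arch, pretrained):
        raise NotImplementedError


class MyDataset(BaseDataSet):
    # Corrected signature: `self` comes first, followed by the two
    # arguments the caller passes explicitly.
    def get_model(self, arch, pretrained):
        assert not pretrained, "pretrained only available for ImageNet"
        # Placeholder for building a real network with self.num_classes outputs.
        return {"arch": arch, "num_classes": self.num_classes}


ds = MyDataset("ants", num_classes=2)

# A bound-method call supplies `self` implicitly...
via_instance = ds.get_model("resnet18", False)
# ...so it matches calling the function on the class with `ds` passed first.
via_class = MyDataset.get_model(ds, "resnet18", False)

print(via_instance == via_class)  # True
print(via_instance)  # {'arch': 'resnet18', 'num_classes': 2}
```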

Icxa commented 3 years ago

@andrewilyas Thank you very much. :)