gpleiss / temperature_scaling

A simple way to calibrate your neural network.
MIT License

DataLoader type mismatch #11

Open AOccLib opened 5 years ago

AOccLib commented 5 years ago

Hi, I'm trying to run your code, but I get an error related to the DataLoader. I have a pre-trained model, model1.pth, and a DataLoader object, data.valid_dl.dl, which holds the validation data I used during training.

The DataLoader seems correct to me. Here are some checks:

type(data.valid_dl.dl)
torch.utils.data.dataloader.DataLoader

data.valid_dl.dl.dataset
LabelList (2083 items)
x: ImageList
Image (3, 227, 227),Image (3, 227, 227),Image (3, 227, 227),Image (3, 227, 227),Image (3, 227, 227)
y: CategoryList
bempty,bempty,bempty,bempty,bempty
Path: ...
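set_temperature() iterates `for input, label in valid_loader`, moves `input` to the GPU, and collects `label` for the loss, so every batch has to unpack into an input tensor and a tensor of integer class indices. A quick way to check this is to pull a single batch. The sketch below builds a hypothetical stand-in for data.valid_dl.dl from random tensors with torch's TensorDataset; the image size matches the issue, while the batch size and the two classes are assumptions. With the real loader, only the `next(iter(...))` part is needed.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for data.valid_dl.dl: 10 random "images" of shape
# (3, 227, 227) with integer class labels in {0, 1}.
images = torch.randn(10, 3, 227, 227)
targets = torch.randint(0, 2, (10,))
valid_loader = DataLoader(TensorDataset(images, targets), batch_size=4)

# set_temperature() unpacks each batch as `input, label`, so inspect one batch.
batch_input, batch_label = next(iter(valid_loader))
print(batch_input.shape)  # torch.Size([4, 3, 227, 227])
print(batch_label.dtype)  # torch.int64, i.e. class indices
```

If the real loader passes this check, the loader side of the contract is probably fine, and the error has to come from elsewhere in the call.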

I'm trying to run this within fastai; here's what I was trying to do:

from fastai.vision import *  # fastai v1: ImageDataBunch, cnn_learner, models, accuracy
from temperature_scaling import ModelWithTemperature

# Load data
data = ImageDataBunch.from_folder(path)

# Train
learn = cnn_learner(data, models.resnet34, metrics=accuracy)
learn.fit_one_cycle(10)
learn.save('model1')

orig_model = ".../model1.pth" # create an uncalibrated model somehow
valid_loader = data.valid_dl.dl # Create a DataLoader from the SAME VALIDATION SET used to train orig_model

scaled_model = ModelWithTemperature(orig_model)
scaled_model.set_temperature(valid_loader)

I get the following error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-d4da6fd75710> in <module>()
      6 
      7 scaled_model = ModelWithTemperature(orig_model)
----> 8 scaled_model.set_temperature(valid_loader)

/content/temperature_scaling/temperature_scaling.py in set_temperature(self, valid_loader)
     46             for input, label in valid_loader:
     47                 input = input.cuda()
---> 48                 logits = self.model(input)
     49                 logits_list.append(logits)
     50                 labels_list.append(label)

TypeError: 'str' object is not callable
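The failing line is `logits = self.model(input)`: ModelWithTemperature stores whatever it is given and calls it on each batch. In the code above, orig_model is the string ".../model1.pth", so Python ends up trying to call a str. The pure-Python sketch below mirrors only that call, using a hypothetical stand-in class (`ModelWithTemperatureSketch` is not part of the repo), to show that a path string raises this same TypeError while any callable goes through. In the fastai v1 setup above, the callable network is presumably `learn.model` (after `learn.load('model1')` if the learner was recreated); that is an assumption, not something confirmed in this thread.

```python
class ModelWithTemperatureSketch:
    """Hypothetical stand-in that mirrors only the call that fails in
    set_temperature(): each batch input is passed to self.model."""

    def __init__(self, model):
        self.model = model  # expected to be a callable network, not a file path

    def set_temperature(self, valid_loader):
        logits_list = []
        for inputs, label in valid_loader:
            logits_list.append(self.model(inputs))  # TypeError if self.model is a str
        return logits_list


# Toy (input, label) pairs standing in for a validation loader.
valid_loader = [([0.1, 0.2], 0), ([0.3, 0.4], 1)]

# Passing a file path reproduces the error from the traceback.
message = None
try:
    ModelWithTemperatureSketch(".../model1.pth").set_temperature(valid_loader)
except TypeError as err:
    message = str(err)
print(message)  # 'str' object is not callable

# Passing any callable works; here a toy "model" that just sums its input.
logits = ModelWithTemperatureSketch(sum).set_temperature(valid_loader)
```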

I'm probably misunderstanding what valid_loader has to contain. Any help would be more than welcome.
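After collecting logits and labels, set_temperature fits a single scalar T so that softmax(logits / T) minimizes negative log-likelihood on the validation set. The labels are used as integer class indices (the repo passes them to nn.CrossEntropyLoss). The stdlib-only sketch below shows that objective on made-up toy logits. It swaps the repo's LBFGS optimizer for a plain grid search over T, and `nll` and `fit_temperature` are hypothetical helpers, not functions from temperature_scaling.py.

```python
import math


def nll(logits, labels, T):
    """Mean negative log-likelihood of softmax(logits / T) at the true labels."""
    total = 0.0
    for row, y in zip(logits, labels):
        scaled = [z / T for z in row]
        m = max(scaled)  # subtract the max before exp for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in scaled))
        total += log_sum_exp - scaled[y]  # -log softmax probability of class y
    return total / len(labels)


def fit_temperature(logits, labels, grid=None):
    """Pick the T on a grid that minimizes validation NLL (stand-in for LBFGS)."""
    if grid is None:
        grid = [0.05 * k for k in range(1, 201)]  # candidate T values in (0, 10]
    return min(grid, key=lambda T: nll(logits, labels, T))


# Toy validation logits for 2 classes: always a margin of 4, right 3 times out of 4.
logits = [[4.0, 0.0], [4.0, 0.0], [4.0, 0.0], [0.0, 4.0]]
labels = [0, 0, 1, 1]  # integer class indices, as CrossEntropyLoss expects

T = fit_temperature(logits, labels)
print(round(T, 2), nll(logits, labels, 1.0), nll(logits, labels, T))
```

Because this toy model is always very confident but only 75% accurate, the fitted T lands well above 1, which softens the predicted probabilities toward that accuracy.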

I've also started a thread about calibration on the fastai forums, citing your paper, talk, and code. If you'd rather reply there to reach a larger audience, that would be great. Here's the link to the thread:

https://forums.fast.ai/t/calibrating-your-network-to-get-better-probability-estimates/41425

Thank you!