gpleiss / temperature_scaling

A simple way to calibrate your neural network.
MIT License

my simple correction for reproducing good performance #20

Open hyunjunChhoi opened 4 years ago

hyunjunChhoi commented 4 years ago

(1) Adding model.eval():

    model.eval()
    model.set_temperature(valid_loader)

Without model.eval():
Before temperature - NLL: 2.441, ECE: 0.247
Optimal temperature: 1386304.250
After temperature - NLL: 4.605, ECE: 0.580

With model.eval():
Before temperature - NLL: 2.249, ECE: 0.234
Optimal temperature: 5.823
After temperature - NLL: 1.572, ECE: 0.271
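Why step (1) matters: layers such as dropout behave differently in training mode, so the logits fed to the temperature optimizer change from one forward pass to the next. A minimal sketch, assuming a hypothetical toy classifier with a dropout layer (not the network used in this issue):

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy classifier with dropout, standing in for a real network.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(16, 3))
x = torch.randn(4, 8)

model.train()  # dropout active: two passes over the same input disagree
with torch.no_grad():
    train_a, train_b = model(x), model(x)

model.eval()  # dropout disabled: passes over the same input are deterministic
with torch.no_grad():
    eval_a, eval_b = model(x), model(x)

print("train-mode logits identical:", torch.equal(train_a, train_b))
print("eval-mode logits identical:", torch.equal(eval_a, eval_b))
```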

(2) Adding optimizer.zero_grad() to the closure:

    def eval():
        loss = nll_criterion(self.temperature_scale(logits), labels)
        optimizer.zero_grad()
        loss.backward()
        return loss

Then:
Before temperature - NLL: 2.249, ECE: 0.234
Optimal temperature: 1.800
After temperature - NLL: 1.460, ECE: 0.139
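A likely reason step (2) helps: PyTorch accumulates .grad across backward() calls until it is cleared, and torch.optim.LBFGS re-evaluates the closure many times inside a single step(). A minimal sketch of the accumulation effect, assuming random stand-in logits and labels rather than the issue's validation set:

```python
import torch
from torch import nn, optim

torch.manual_seed(0)
logits = torch.randn(32, 5)          # hypothetical validation logits
labels = torch.randint(0, 5, (32,))  # hypothetical validation labels
temperature = nn.Parameter(torch.ones(1) * 1.5)
nll_criterion = nn.CrossEntropyLoss()
optimizer = optim.LBFGS([temperature], lr=0.01, max_iter=50)

def compute_loss():
    return nll_criterion(logits / temperature, labels)

# Two backward() calls without zeroing: the second gradient is added to the first.
compute_loss().backward()
g1 = temperature.grad.clone()
compute_loss().backward()
g2 = temperature.grad.clone()  # 2 * g1, not the gradient of the current loss

# Zeroing first leaves only the gradient of the current loss.
optimizer.zero_grad()
compute_loss().backward()
g3 = temperature.grad.clone()  # equal to g1 again

print(g1.item(), g2.item(), g3.item())
```

Inside an L-BFGS closure, every evaluation after the first would otherwise see an inflated gradient, which is consistent with the much more reasonable temperature obtained after this change.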

(3) Changing max_iter from 40 (the default) to 10000:

Before temperature - NLL: 2.249, ECE: 0.234
Optimal temperature: 3.035
After temperature - NLL: 1.246, ECE: 0.016
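A minimal sketch of the fitting loop with max_iter exposed as a parameter. fit_temperature is a hypothetical helper, not the repository's code; lr=0.01 and the initial temperature of 1.5 are assumptions. The data is synthetic: labels are sampled from softmax(z) and the "model" reports 3 * z, so the fitted temperature should land near 3 once the optimizer has enough iterations.

```python
import torch
from torch import nn, optim

def fit_temperature(logits, labels, max_iter=10000, lr=0.01):
    """Hypothetical helper: fit one temperature T by minimizing NLL of logits / T."""
    temperature = nn.Parameter(torch.ones(1) * 1.5)  # assumed initial value
    nll = nn.CrossEntropyLoss()
    optimizer = optim.LBFGS([temperature], lr=lr, max_iter=max_iter)

    def closure():
        optimizer.zero_grad()  # clear gradients left by the previous closure call
        loss = nll(logits / temperature, labels)
        loss.backward()
        return loss

    optimizer.step(closure)  # runs up to max_iter L-BFGS iterations
    return temperature.item()

torch.manual_seed(0)
true_logits = torch.randn(2000, 10) * 1.5
labels = torch.multinomial(torch.softmax(true_logits, dim=1), 1).squeeze(1)
overconfident = 3.0 * true_logits  # synthetic: logits are 3x too sharp

t_short = fit_temperature(overconfident, labels, max_iter=40)
t_long = fit_temperature(overconfident, labels, max_iter=10000)
print(f"max_iter=40: T={t_short:.3f}  max_iter=10000: T={t_long:.3f}")
```

Without a line search, PyTorch's L-BFGS moves by lr times the quasi-Newton direction on every iteration, so with lr=0.01 each step covers only a small fraction of the distance to the optimum. That is one plausible explanation for why a small max_iter stops short; passing line_search_fn="strong_wolfe" to optim.LBFGS may be an alternative to raising max_iter.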

(4) Changing the NLL loss to the ECE loss (still with max_iter 10000):

    def eval():
        # loss = nll_criterion(self.temperature_scale(logits), labels)
        loss = ece_criterion(self.temperature_scale(logits), labels)
        optimizer.zero_grad()
        loss.backward()
        return loss

Before temperature - NLL: 2.249, ECE: 0.234
Optimal temperature: 2.981
After temperature - NLL: 1.247, ECE: 0.011
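For reference, a minimal sketch of a binned expected calibration error written so it can be backpropagated through, as step (4) requires. ece_loss and n_bins=15 are assumptions for illustration; the repository's own ece_criterion may differ in details. Bin membership is not differentiable, so gradients reach the temperature only through the mean confidence inside each bin.

```python
import torch

def ece_loss(logits, labels, n_bins=15):
    """Hypothetical ECE: sum over bins of (bin weight) * |mean confidence - accuracy|."""
    probs = torch.softmax(logits, dim=1)
    confidences, predictions = probs.max(dim=1)
    accuracies = predictions.eq(labels).float()
    bin_edges = torch.linspace(0, 1, n_bins + 1)
    ece = torch.zeros((), dtype=logits.dtype)
    for lower, upper in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        prop_in_bin = in_bin.float().mean()  # fraction of samples in this bin
        if prop_in_bin.item() > 0:
            gap = confidences[in_bin].mean() - accuracies[in_bin].mean()
            ece = ece + gap.abs() * prop_in_bin
    return ece

# Four samples, all with confidence sigmoid(2) ~ 0.881, but only 3 of 4 correct.
demo_logits = torch.tensor([[2.0, 0.0], [2.0, 0.0], [0.0, 2.0], [0.0, 2.0]])
demo_labels = torch.tensor([0, 1, 1, 1])
print(ece_loss(demo_logits, demo_labels).item())
```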

I think this reproduces performance similar to the original paper, though the result may depend on the hyperparameters.

Are there any other issues? I'm not sure whether these are the right corrections.

hyunjunChhoi commented 4 years ago

On the test data I got:

Test Accuracy of the model on the test images: 66.1 %
Before temperature - NLL: 2.199, ECE: 0.227
After temperature - NLL: 1.216, ECE: 0.012

ltong1130ztr commented 4 years ago

I think calling model.eval() inside the ModelWithTemperature class definition is safer. Some classifiers have batch normalization or dropout layers, and if someone loads this kind of model without switching it to eval mode, the raw output logits are corrupted: dropout stays active and batch norm normalizes with batch statistics instead of its running estimates. It would look something like this:

import torch
from torch import nn

class ModelWithTemperature(nn.Module):
    """
    A thin decorator, which wraps a model with temperature scaling
    model (nn.Module):
        A classification neural network
        NB: Output of the neural network should be the classification logits,
            NOT the softmax (or log softmax)!
    """
    def __init__(self, model):
        super(ModelWithTemperature, self).__init__()
        self.model = model
        self.model.eval()  # disable dropout and use batch-norm running stats before calibrating
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)
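The snippet above stops after __init__. A self-contained sketch of how the remaining forward path of such a wrapper might look, assuming forward simply returns model logits divided by the temperature; this is not the repository's full class (which also has set_temperature), and the test network is a hypothetical stand-in with both batch norm and dropout.

```python
import torch
from torch import nn

class ModelWithTemperature(nn.Module):
    """Wraps a logit-producing classifier and divides its logits by a learned temperature."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.model.eval()  # dropout off, batch norm uses running statistics
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def temperature_scale(self, logits):
        # The single-element temperature broadcasts over every logit.
        return logits / self.temperature

    def forward(self, x):
        return self.temperature_scale(self.model(x))

torch.manual_seed(0)
# Hypothetical classifier with both batch norm and dropout.
net = nn.Sequential(
    nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 3)
)
scaled = ModelWithTemperature(net)

x = torch.randn(5, 4)
with torch.no_grad():
    full_batch = scaled(x)
    first_alone = scaled(x[:1])  # same first sample, evaluated as a batch of one
print(net.training, full_batch[0], first_alone[0])
```

Because the wrapped network is in eval mode, a sample's logits no longer depend on which other samples share its batch (in training mode, BatchNorm1d would even refuse a batch of one). One caveat: nn.Module.train() propagates to child modules, so calling scaled.train() later would switch the wrapped model back to training mode.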