gpleiss / temperature_scaling

A simple way to calibrate your neural network.
MIT License
1.09k stars 159 forks source link

BatchNorm parameters are modified; they should not be! #36

Open joihn opened 4 months ago

joihn commented 4 months ago

summary

set_temperature() changes not only the temperature (desired) but also the running statistics (running_mean, running_var and num_batches_tracked) of every BatchNorm layer in the network (undesired). I think the whole network should be frozen except for the temperature parameter.

how to fix

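Set the BatchNorm momentum to 0 before calling set_temperature() (snippet tested on another codebase, not this one). Note that momentum = 0 alone leaves the layers in training mode. They therefore still normalize with per-batch statistics and still increment num_batches_tracked. Also calling module.eval() on each BatchNorm layer avoids both problems.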

    def freeze_batch_norm(self):
        """Freeze the running statistics of every BatchNorm layer in this module.

        Meant to be called before ``set_temperature()`` so that calibrating the
        temperature leaves ``running_mean``, ``running_var`` and
        ``num_batches_tracked`` of the wrapped network untouched.

        Every BatchNorm layer (anything deriving from ``_BatchNorm``, i.e.
        BatchNorm1d/2d/3d and SyncBatchNorm) is:

        * switched to eval mode, so it normalizes with its stored running
          statistics (matching inference) and no longer updates them;
        * given ``momentum = 0``, so ``running_mean``/``running_var`` stay fixed
          even if the module is later put back into training mode.

        Non-BatchNorm submodules, and their training flags, are left unchanged.
        """
        for module in self.modules():
            # The previous check only matched BatchNorm2d; _BatchNorm is the
            # shared base class of all BatchNorm variants.
            if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                # momentum = 0 alone keeps training mode on: the layer would
                # still normalize with per-batch statistics and still increment
                # num_batches_tracked. eval() stops both.
                module.eval()
                module.momentum = 0

In case it is useful to someone, here is my script to diff two model checkpoints

(auto-generated with LLMs and tested on this codebase)

import torch

def load_checkpoint(filepath):
    """Read a checkpoint file from disk and return the deserialized object.

    Every tensor is mapped onto the CPU, so checkpoints saved on a GPU can be
    inspected on a machine without one.

    NOTE(review): ``torch.load`` is pickle-based; only load checkpoints that
    come from a trusted source.
    """
    cpu = torch.device('cpu')
    checkpoint = torch.load(filepath, map_location=cpu)
    return checkpoint

def normalize_key(key, prefix="model."):
    """Strip one leading wrapper prefix from a checkpoint key.

    A wrapper module that stores the network as an attribute named ``model``
    produces state-dict keys such as ``"model.features.norm0.weight"``.
    Stripping that prefix lets those keys match the unwrapped network's keys.

    Args:
        key: Dictionary key to normalize. Non-string keys (e.g. the integer
            parameter ids in an optimizer ``state`` dict) are returned as-is
            instead of raising ``AttributeError``.
        prefix: Prefix to remove. Defaults to ``"model."``; pass e.g.
            ``"module."`` for ``DataParallel`` checkpoints.

    Returns:
        ``key`` without a single leading ``prefix``, or ``key`` unchanged.
    """
    if isinstance(key, str) and key.startswith(prefix):
        return key[len(prefix):]
    return key

def compare_tensors(tensor1, tensor2, prefix, changed_layers):
    """Append ``prefix`` to ``changed_layers`` if the two tensors differ.

    The tensors count as identical only when ``torch.equal`` reports the same
    size and the same elements. ``changed_layers`` is mutated in place and
    nothing is returned.
    """
    identical = torch.equal(tensor1, tensor2)
    if identical:
        return
    changed_layers.append(prefix)

def _values_equal(value1, value2):
    """Return True if two same-typed, non-dict checkpoint entries are equal.

    Handles plain values such as epochs, hyper-parameters, ``None``, and lists
    like optimizer param groups. Containers are walked recursively so that
    tensors nested inside them are compared with ``torch.equal``. Plain ``==``
    on a multi-element tensor yields a tensor whose truth value is ambiguous.
    """
    if isinstance(value1, torch.Tensor) or isinstance(value2, torch.Tensor):
        return (
            isinstance(value1, torch.Tensor)
            and isinstance(value2, torch.Tensor)
            and torch.equal(value1, value2)
        )
    if isinstance(value1, dict) and isinstance(value2, dict):
        return value1.keys() == value2.keys() and all(
            _values_equal(value1[k], value2[k]) for k in value1
        )
    if isinstance(value1, (list, tuple)) and isinstance(value2, (list, tuple)):
        return (
            type(value1) is type(value2)
            and len(value1) == len(value2)
            and all(_values_equal(a, b) for a, b in zip(value1, value2))
        )
    return bool(value1 == value2)


def compare_dicts(dict1, dict2, prefix, changed_layers):
    """Recursively compare two (possibly nested) checkpoint dictionaries.

    Keys on both sides are passed through ``normalize_key`` before matching.
    For every key present in both dictionaries:

    * tensor vs tensor: compared with ``compare_tensors``;
    * dict vs dict: compared recursively;
    * other values of the same type: compared for equality, and the path is
      recorded in ``changed_layers`` when they differ;
    * values of different types: a "Type mismatch" message is printed.

    Keys present on only one side are reported with a printed message.

    NOTE: if two keys in the same dictionary normalize to the same name, the
    later one silently wins in the normalized mapping.

    Args:
        dict1: First mapping (e.g. a loaded state_dict).
        dict2: Second mapping, compared against ``dict1``.
        prefix: Dotted path of these dictionaries inside the top-level
            checkpoint; ``""`` at the root.
        changed_layers: List that receives the dotted path of every entry
            whose value differs. It is mutated in place.
    """
    # Normalize the keys in both dictionaries
    normalized_dict1 = {normalize_key(k): v for k, v in dict1.items()}
    normalized_dict2 = {normalize_key(k): v for k, v in dict2.items()}

    for key, value1 in normalized_dict1.items():
        # Build the dotted path once, so every record and message uses the same
        # form. The old type-mismatch message produced ".key" at the root.
        path = f"{prefix}.{key}" if prefix else str(key)
        if key not in normalized_dict2:
            print(f"Key {key} found in the first model but not in the second at {prefix}")
            continue
        value2 = normalized_dict2[key]
        if isinstance(value1, torch.Tensor) and isinstance(value2, torch.Tensor):
            compare_tensors(value1, value2, path, changed_layers)
        elif isinstance(value1, dict) and isinstance(value2, dict):
            compare_dicts(value1, value2, path, changed_layers)
        elif type(value1) is type(value2):
            # Same-typed plain values (ints, floats, strings, None, lists...)
            # were previously always reported as a "Type mismatch", even when
            # identical, and real differences were never recorded.
            if not _values_equal(value1, value2):
                changed_layers.append(path)
        else:
            print(f"Type mismatch at {path}")

    for key in normalized_dict2:
        if key not in normalized_dict1:
            print(f"Key {key} found in the second model but not in the first at {prefix}")

def compare_models(checkpoint1, checkpoint2):
    """Load two checkpoint files and list the entries whose contents differ.

    Both files are read with ``load_checkpoint`` (first ``checkpoint1``, then
    ``checkpoint2``) and walked by ``compare_dicts`` from the root path ``""``.
    Missing keys and type mismatches are printed by ``compare_dicts``.

    Returns:
        List of dotted key paths whose values differ.
    """
    first, second = (load_checkpoint(path) for path in (checkpoint1, checkpoint2))
    differing = []
    compare_dicts(first, second, "", differing)
    return differing

if __name__ == "__main__":
    # Guarded so that importing this module (e.g. to reuse compare_models)
    # does not immediately try to load the placeholder paths below.
    import sys

    # Paths to your model checkpoints. They are taken from the command line
    # when given; otherwise edit the placeholders below before running.
    checkpoint_path1 = sys.argv[1] if len(sys.argv) > 1 else '/home/<yourpath>/model_with_temperature.pth'
    checkpoint_path2 = sys.argv[2] if len(sys.argv) > 2 else '/home/<yourpath>/model_without_temp.pth'

    # Compare the models
    changed_layers = compare_models(checkpoint_path1, checkpoint_path2)
    if changed_layers:
        print("Changed layers:")
        for layer in changed_layers:
            print(layer)
    else:
        print("No changes found between the models.")

Output of the script above:

Key temperature found in the first model but not in the second at 
Changed layers:
features.denseblock1.denselayer1.norm1.running_mean
features.denseblock1.denselayer1.norm1.running_var
features.denseblock1.denselayer1.norm1.num_batches_tracked
features.denseblock1.denselayer1.norm2.running_mean
features.denseblock1.denselayer1.norm2.running_var
features.denseblock1.denselayer1.norm2.num_batches_tracked
features.denseblock1.denselayer2.norm1.running_mean
features.denseblock1.denselayer2.norm1.running_var
features.denseblock1.denselayer2.norm1.num_batches_tracked
features.denseblock1.denselayer2.norm2.running_mean
features.denseblock1.denselayer2.norm2.running_var
features.denseblock1.denselayer2.norm2.num_batches_tracked
features.denseblock1.denselayer3.norm1.running_mean
features.denseblock1.denselayer3.norm1.running_var
features.denseblock1.denselayer3.norm1.num_batches_tracked
features.denseblock1.denselayer3.norm2.running_mean
features.denseblock1.denselayer3.norm2.running_var
features.denseblock1.denselayer3.norm2.num_batches_tracked
features.denseblock1.denselayer4.norm1.running_mean
features.denseblock1.denselayer4.norm1.running_var
features.denseblock1.denselayer4.norm1.num_batches_tracked
features.denseblock1.denselayer4.norm2.running_mean
features.denseblock1.denselayer4.norm2.running_var
features.denseblock1.denselayer4.norm2.num_batches_tracked
features.denseblock1.denselayer5.norm1.running_mean
features.denseblock1.denselayer5.norm1.running_var
features.denseblock1.denselayer5.norm1.num_batches_tracked
features.denseblock1.denselayer5.norm2.running_mean
features.denseblock1.denselayer5.norm2.running_var
features.denseblock1.denselayer5.norm2.num_batches_tracked
features.denseblock1.denselayer6.norm1.running_mean
features.denseblock1.denselayer6.norm1.running_var
features.denseblock1.denselayer6.norm1.num_batches_tracked
features.denseblock1.denselayer6.norm2.running_mean
features.denseblock1.denselayer6.norm2.running_var
features.denseblock1.denselayer6.norm2.num_batches_tracked
features.transition1.norm.running_mean
features.transition1.norm.running_var
features.transition1.norm.num_batches_tracked
features.denseblock2.denselayer1.norm1.running_mean
features.denseblock2.denselayer1.norm1.running_var
features.denseblock2.denselayer1.norm1.num_batches_tracked
features.denseblock2.denselayer1.norm2.running_mean
features.denseblock2.denselayer1.norm2.running_var
features.denseblock2.denselayer1.norm2.num_batches_tracked
features.denseblock2.denselayer2.norm1.running_mean
features.denseblock2.denselayer2.norm1.running_var
features.denseblock2.denselayer2.norm1.num_batches_tracked
features.denseblock2.denselayer2.norm2.running_mean
features.denseblock2.denselayer2.norm2.running_var
features.denseblock2.denselayer2.norm2.num_batches_tracked
features.denseblock2.denselayer3.norm1.running_mean
features.denseblock2.denselayer3.norm1.running_var
features.denseblock2.denselayer3.norm1.num_batches_tracked
features.denseblock2.denselayer3.norm2.running_mean
features.denseblock2.denselayer3.norm2.running_var
features.denseblock2.denselayer3.norm2.num_batches_tracked
features.denseblock2.denselayer4.norm1.running_mean
features.denseblock2.denselayer4.norm1.running_var
features.denseblock2.denselayer4.norm1.num_batches_tracked
features.denseblock2.denselayer4.norm2.running_mean
features.denseblock2.denselayer4.norm2.running_var
features.denseblock2.denselayer4.norm2.num_batches_tracked
features.denseblock2.denselayer5.norm1.running_mean
features.denseblock2.denselayer5.norm1.running_var
features.denseblock2.denselayer5.norm1.num_batches_tracked
features.denseblock2.denselayer5.norm2.running_mean
features.denseblock2.denselayer5.norm2.running_var
features.denseblock2.denselayer5.norm2.num_batches_tracked
features.denseblock2.denselayer6.norm1.running_mean
features.denseblock2.denselayer6.norm1.running_var
features.denseblock2.denselayer6.norm1.num_batches_tracked
features.denseblock2.denselayer6.norm2.running_mean
features.denseblock2.denselayer6.norm2.running_var
features.denseblock2.denselayer6.norm2.num_batches_tracked
features.transition2.norm.running_mean
features.transition2.norm.running_var
features.transition2.norm.num_batches_tracked
features.denseblock3.denselayer1.norm1.running_mean
features.denseblock3.denselayer1.norm1.running_var
features.denseblock3.denselayer1.norm1.num_batches_tracked
features.denseblock3.denselayer1.norm2.running_mean
features.denseblock3.denselayer1.norm2.running_var
features.denseblock3.denselayer1.norm2.num_batches_tracked
features.denseblock3.denselayer2.norm1.running_mean
features.denseblock3.denselayer2.norm1.running_var
features.denseblock3.denselayer2.norm1.num_batches_tracked
features.denseblock3.denselayer2.norm2.running_mean
features.denseblock3.denselayer2.norm2.running_var
features.denseblock3.denselayer2.norm2.num_batches_tracked
features.denseblock3.denselayer3.norm1.running_mean
features.denseblock3.denselayer3.norm1.running_var
features.denseblock3.denselayer3.norm1.num_batches_tracked
features.denseblock3.denselayer3.norm2.running_mean
features.denseblock3.denselayer3.norm2.running_var
features.denseblock3.denselayer3.norm2.num_batches_tracked
features.denseblock3.denselayer4.norm1.running_mean
features.denseblock3.denselayer4.norm1.running_var
features.denseblock3.denselayer4.norm1.num_batches_tracked
features.denseblock3.denselayer4.norm2.running_mean
features.denseblock3.denselayer4.norm2.running_var
features.denseblock3.denselayer4.norm2.num_batches_tracked
features.denseblock3.denselayer5.norm1.running_mean
features.denseblock3.denselayer5.norm1.running_var
features.denseblock3.denselayer5.norm1.num_batches_tracked
features.denseblock3.denselayer5.norm2.running_mean
features.denseblock3.denselayer5.norm2.running_var
features.denseblock3.denselayer5.norm2.num_batches_tracked
features.denseblock3.denselayer6.norm1.running_mean
features.denseblock3.denselayer6.norm1.running_var
features.denseblock3.denselayer6.norm1.num_batches_tracked
features.denseblock3.denselayer6.norm2.running_mean
features.denseblock3.denselayer6.norm2.running_var
features.denseblock3.denselayer6.norm2.num_batches_tracked
features.norm_final.running_mean
features.norm_final.running_var
features.norm_final.num_batches_tracked