set_temperature() will not only change the temperature parameter (desired),
but also the running statistics (the running_mean, running_var buffers) of every BatchNorm layer in the network (undesired).
I think the whole network should be frozen except the temperature scaling.

How to fix
Set the BatchNorm momentum to 0 before calling set_temperature():
(snippet tested on another codebase, not this one)
import torch

def freeze_batch_norm(self):
    # momentum=0 keeps the current running stats unchanged, since the
    # update rule is x_new = (1 - momentum) * x_old + momentum * x_batch
    for module in self.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            module.momentum = 0
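Quick sanity check that momentum = 0 really freezes the running stats (standalone sketch, not from this codebase; any BatchNorm layer will do):

import torch

# With the default momentum, a forward pass in train mode moves
# running_mean; with momentum = 0 the running stats stay put.
bn = torch.nn.BatchNorm2d(3)
x = torch.randn(8, 3, 4, 4)

bn.train()
before = bn.running_mean.clone()
bn(x)
print(torch.equal(before, bn.running_mean))  # False: stats were updated

bn.momentum = 0
before = bn.running_mean.clone()
bn(x)
print(torch.equal(before, bn.running_mean))  # True: stats are frozen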
In case it could be useful to someone, here is my script to diff two model checkpoints
(auto-generated with an LLM, tested on this codebase):
import torch

def load_checkpoint(filepath):
    """Load a PyTorch model checkpoint."""
    return torch.load(filepath, map_location=torch.device('cpu'))

def normalize_key(key):
    """Normalize dictionary keys by removing specific prefixes."""
    prefix = "model."
    if key.startswith(prefix):
        return key[len(prefix):]
    return key

def compare_tensors(tensor1, tensor2, prefix, changed_layers):
    """Helper function to compare two tensors."""
    if not torch.equal(tensor1, tensor2):
        changed_layers.append(prefix)

def compare_dicts(dict1, dict2, prefix, changed_layers):
    """Recursively compare dictionaries that may contain tensors."""
    # Normalize the keys in both dictionaries
    normalized_dict1 = {normalize_key(k): v for k, v in dict1.items()}
    normalized_dict2 = {normalize_key(k): v for k, v in dict2.items()}
    for key in normalized_dict1.keys():
        if key in normalized_dict2:
            if isinstance(normalized_dict1[key], torch.Tensor) and isinstance(normalized_dict2[key], torch.Tensor):
                compare_tensors(normalized_dict1[key], normalized_dict2[key], f"{prefix}.{key}" if prefix else key, changed_layers)
            elif isinstance(normalized_dict1[key], dict) and isinstance(normalized_dict2[key], dict):
                compare_dicts(normalized_dict1[key], normalized_dict2[key], f"{prefix}.{key}" if prefix else key, changed_layers)
            else:
                print(f"Type mismatch at {prefix}.{key}")
        else:
            print(f"Key {key} found in the first model but not in the second at {prefix}")
    for key in normalized_dict2.keys():
        if key not in normalized_dict1:
            print(f"Key {key} found in the second model but not in the first at {prefix}")

def compare_models(checkpoint1, checkpoint2):
    """Compare two model checkpoints."""
    model1 = load_checkpoint(checkpoint1)
    model2 = load_checkpoint(checkpoint2)
    changed_layers = []
    compare_dicts(model1, model2, "", changed_layers)
    return changed_layers

# Paths to your model checkpoints
checkpoint_path1 = '/home/<yourpath>/model_with_temperature.pth'
checkpoint_path2 = '/home/<yourpath>/model_without_temp.pth'

# Compare the models
changed_layers = compare_models(checkpoint_path1, checkpoint_path2)
if changed_layers:
    print("Changed layers:")
    for layer in changed_layers:
        print(layer)
else:
    print("No changes found between the models.")