Closed sfluegel05 closed 1 month ago
We can fix the issue by explicitly setting the `weights_only` parameter to `False` when using `torch.load`.
In our project, `torch.load` is primarily used to:
I’ve checked the internal code of `torch.load` in the current PyTorch version (2.4.1), and if no value is provided for `weights_only`, it defaults to `False` internally.
Relevant code:
https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L1280-L1281
Additionally, I’ve created a program to verify the equality between objects/data/models loaded with no value for `weights_only` and when the parameter is explicitly set to `False`. This ensures no unintended behavior or consequences in our pipeline.
import numpy as np
import torch
def compare_dicts(dict1, dict2):
    """
    Compare two dictionaries value by value, recursing into nested dictionaries.

    Values are compared according to their type:

    * the 'labels' key, and any value that is a NumPy array on either side,
      is compared with ``np.array_equal``;
    * pairs of ``torch.Tensor`` values are compared with ``compare_tensors``;
    * pairs of dictionaries are compared recursively;
    * everything else is compared with ``!=``.

    Args:
        dict1 (dict): First dictionary to compare.
        dict2 (dict): Second dictionary to compare.

    Returns:
        bool: True if dictionaries are equal, False otherwise.
    """
    if dict1.keys() != dict2.keys():
        return False

    for key, left in dict1.items():
        right = dict2[key]

        if (
            key == 'labels'
            or isinstance(left, np.ndarray)
            or isinstance(right, np.ndarray)
        ):
            # A plain `!=` on arrays is elementwise and cannot be used in a
            # boolean context, so arrays always go through np.array_equal.
            if not np.array_equal(left, right):
                return False
        elif isinstance(left, torch.Tensor) and isinstance(right, torch.Tensor):
            if not compare_tensors(left, right):
                return False
        elif isinstance(left, dict) and isinstance(right, dict):
            # Only bail out on a mismatch: returning the nested result
            # directly would skip every key that follows this one.
            if not compare_dicts(left, right):
                return False
        elif left != right:
            # For other values, perform normal comparison.
            return False

    return True
def compare_list_of_dicts(list1, list2):
    """
    Check two lists of dictionaries for pairwise equality.

    The lists must have the same length, and each pair of dictionaries at
    the same position must be equal according to ``compare_dicts``.

    Args:
        list1 (list): First list of dictionaries to compare.
        list2 (list): Second list of dictionaries to compare.

    Returns:
        bool: True if both lists of dictionaries are equal, False otherwise.
    """
    same_length = len(list1) == len(list2)
    # `all` stops at the first mismatching pair, like an early return would.
    return same_length and all(
        compare_dicts(left, right) for left, right in zip(list1, list2)
    )
def compare_tensors(tensor1, tensor2):
    """Return the result of ``torch.equal`` for the two given tensors."""
    tensors_match = torch.equal(tensor1, tensor2)
    return tensors_match
def compare_models(model1, model2):
    """
    Compare two objects obtained from torch.load.

    Two dictionaries are compared directly with ``compare_dicts``; two
    ``torch.nn.Module`` instances are compared through their state_dicts.
    Any other combination of types is reported as not equal.

    Args:
        model1: First loaded model.
        model2: Second loaded model.

    Returns:
        bool: True if the models are equal, False otherwise.
    """
    both_dicts = isinstance(model1, dict) and isinstance(model2, dict)
    if both_dicts:
        return compare_dicts(model1, model2)

    both_modules = isinstance(model1, torch.nn.Module) and isinstance(
        model2, torch.nn.Module
    )
    if both_modules:
        # Modules are compared via their state_dicts.
        return compare_dicts(model1.state_dict(), model2.state_dict())

    return False
if __name__ == "__main__":
    import torch

    # Each file is loaded twice: once without `weights_only` and once with
    # `weights_only=False`, so the two results can be compared.
    data_path = "data/chebi_v200/ChEBI50/processed/smiles_token/data.pt"
    checkpoint_path = "G:/github-aditya0by0/electra_pretrained.ckpt"

    # Processed data: compare the lists of dictionaries.
    data_with_no_param = torch.load(data_path)
    data_with_false = torch.load(data_path, weights_only=False)
    are_equal = compare_list_of_dicts(data_with_no_param, data_with_false)
    print(f"Data comparison result (no_param vs false): {are_equal}")

    # Model checkpoint, mapped to CPU.
    model_with_no_param = torch.load(checkpoint_path, map_location="cpu")
    model_with_false = torch.load(
        checkpoint_path, weights_only=False, map_location="cpu"
    )
    # model_with_true = torch.load(checkpoint_path, weights_only=True, map_location="cpu")

    # Print results
    print(
        "model_with_no_param == model_with_false:",
        compare_models(model_with_no_param, model_with_false),
    )
    # print("model_with_no_param == model_with_true:", torch.equal(model_with_no_param, model_with_true))
    # print("model_with_false == model_with_true:", torch.equal(model_with_false, model_with_true))
Output:
Data comparison result (no_param vs false): True
model_with_no_param == model_with_false: True
If torch uses `False` as the default anyway, we should go ahead and set `False` explicitly.
With the latest PyTorch version (2.4.0), I get the following warning:
We should investigate how to fix this (in the best case: simply by explicitly setting the `weights_only` parameter to `False`) and if this has any consequences for our pipeline.