ChEB-AI / python-chebai

GNU Affero General Public License v3.0

FutureWarning: You are using `torch.load` with `weights_only=False` #51

Closed sfluegel05 closed 1 month ago

sfluegel05 commented 2 months ago

With the latest PyTorch version (2.4.0), I get the following warning:

C:\Users\Simon Flügel\Desktop\chebai\chebai\models\electra.py:259: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
model_dict = torch.load(fin, map_location=self.device)

We should investigate how to fix this (in the best case, simply by explicitly setting the weights_only parameter to False) and whether this has any consequences for our pipeline.
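For context on what the warning means by limiting "the functions that could be executed during unpickling", here is a minimal standard-library sketch of an allowlisting unpickler. It illustrates the general idea only and is not PyTorch's implementation; `SAFE_BUILTINS`, `RestrictedUnpickler`, and `restricted_loads` are hypothetical names chosen for this example.

```python
import builtins
import io
import pickle

# Hypothetical allowlist, for illustration only.
SAFE_BUILTINS = {"range", "complex", "set", "frozenset", "slice"}


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals found on the allowlist."""

    def find_class(self, module, name):
        # Called whenever the pickle stream references a global object.
        if module == "builtins" and name in SAFE_BUILTINS:
            return getattr(builtins, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data: bytes):
    """Like pickle.loads, but rejects anything outside the allowlist."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


# Plain containers and allowlisted types load normally.
plain = restricted_loads(pickle.dumps({"labels": [1, 2, 3], "span": range(3)}))

# A reference to a non-allowlisted global (here builtins.print) is rejected.
try:
    restricted_loads(pickle.dumps(print))
    blocked = False
except pickle.UnpicklingError:
    blocked = True
```

With `weights_only=False`, no such restriction applies, which is why that setting is only appropriate for files we fully control.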

aditya0by0 commented 1 month ago

We can fix the issue by explicitly setting the weights_only parameter to False when using torch.load.
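A minimal sketch of what the explicit setting could look like. The in-memory buffer stands in for a real checkpoint file, and `load_checkpoint` is a hypothetical helper, not a function from our code base:

```python
import io

import torch

# Hypothetical stand-in for a checkpoint file: a tiny state dict kept in memory.
buffer = io.BytesIO()
torch.save({"weight": torch.arange(4.0)}, buffer)


def load_checkpoint(fileobj, device="cpu"):
    """Load a checkpoint while stating the unpickling mode explicitly."""
    fileobj.seek(0)
    # Passing weights_only=False matches the current default behaviour;
    # stating it explicitly should avoid the FutureWarning.
    return torch.load(fileobj, map_location=device, weights_only=False)


state = load_checkpoint(buffer)
```

Since `weights_only=False` allows arbitrary pickled objects, this only makes sense for files we produce ourselves.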

In our project, torch.load is primarily used to:

I’ve checked the internal code of torch.load in the current PyTorch version (2.4.1), and if no value is provided for weights_only, it defaults to False internally.

Relevant code:
https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L1280#L1281

Additionally, I wrote a script that checks whether objects (data and models) loaded without a weights_only argument are equal to those loaded with weights_only explicitly set to False. This confirms that the change has no unintended consequences for our pipeline.

import numpy as np
import torch

def compare_dicts(dict1, dict2):
    """
    Compare two dictionaries, handling ndarray comparison for the 'labels' key.

    Args:
        dict1 (dict): First dictionary to compare.
        dict2 (dict): Second dictionary to compare.

    Returns:
        bool: True if dictionaries are equal, False otherwise.
    """
    if dict1.keys() != dict2.keys():
        return False

    for key in dict1:
        # Handle 'labels' key which contains ndarray
        if key == 'labels':
            if not np.array_equal(dict1[key], dict2[key]):
                return False

        elif isinstance(dict1[key], torch.Tensor) and isinstance(dict2[key], torch.Tensor):
            if not compare_tensors(dict1[key], dict2[key]):
                return False

        elif isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
            # Recurse into nested dicts, but keep checking the remaining keys on success
            if not compare_dicts(dict1[key], dict2[key]):
                return False

        else:
            # For other keys, perform normal comparison
            if dict1[key] != dict2[key]:
                return False

    return True

def compare_list_of_dicts(list1, list2):
    """
    Compare two lists of dictionaries, handling ndarray comparison in the 'labels' key.

    Args:
        list1 (list): First list of dictionaries to compare.
        list2 (list): Second list of dictionaries to compare.

    Returns:
        bool: True if both lists of dictionaries are equal, False otherwise.
    """
    if len(list1) != len(list2):
        return False

    for dict1, dict2 in zip(list1, list2):
        if not compare_dicts(dict1, dict2):
            return False

    return True

def compare_tensors(tensor1, tensor2):
    """Helper function to compare PyTorch tensors."""
    return torch.equal(tensor1, tensor2)

def compare_models(model1, model2):
    """
    Compare two models loaded from torch.load.
    Can handle both dictionaries and nn.Module state_dicts.

    Args:
        model1: First loaded model.
        model2: Second loaded model.

    Returns:
        bool: True if the models are equal, False otherwise.
    """
    if isinstance(model1, dict) and isinstance(model2, dict):
        return compare_dicts(model1, model2)

    elif isinstance(model1, torch.nn.Module) and isinstance(model2, torch.nn.Module):
        # Compare the state_dict of models
        return compare_dicts(model1.state_dict(), model2.state_dict())

    return False

if __name__ == "__main__":
    import torch
    data_with_no_param = torch.load("data/chebi_v200/ChEBI50/processed/smiles_token/data.pt")
    data_with_false = torch.load("data/chebi_v200/ChEBI50/processed/smiles_token/data.pt", weights_only=False)
    # Compare the lists of dictionaries
    are_equal = compare_list_of_dicts(data_with_no_param, data_with_false)
    print(f"Data comparison result (no_param vs false): {are_equal}")

    model_with_no_param = torch.load("G:/github-aditya0by0/electra_pretrained.ckpt", map_location="cpu")
    model_with_false = torch.load("G:/github-aditya0by0/electra_pretrained.ckpt", weights_only=False, map_location="cpu")
    # model_with_true = torch.load("G:/github-aditya0by0/electra_pretrained.ckpt", weights_only=True, map_location="cpu")

    # Print results
    print("model_with_no_param == model_with_false:", compare_models(model_with_no_param, model_with_false))
    # print("model_with_no_param == model_with_true:", compare_models(model_with_no_param, model_with_true))
    # print("model_with_false == model_with_true:", compare_models(model_with_false, model_with_true))

Output:

Data comparison result (no_param vs false): True
model_with_no_param == model_with_false: True

sfluegel05 commented 1 month ago

If torch uses False as the default anyway, we should go ahead and set False explicitly.