huggingface / safetensors

Simple, safe way to store and distribute tensors
https://huggingface.co/docs/safetensors
Apache License 2.0

KeyError: torch.complex64 when attempting to save PyTorch model #450

Closed: NielsRogge closed this issue 3 months ago

NielsRogge commented 6 months ago

System Info

safetensors v0.4.2 huggingface_hub v0.22.0.dev0


Reproduction

We recently made Safetensors the default serialization format for the PyTorchModelHubMixin class in huggingface_hub (https://github.com/huggingface/huggingface_hub/pull/2033). PyTorchModelHubMixin is a minimal class that adds from_pretrained and push_to_hub methods to any custom nn.Module.

However, when trying out this class on the Gemma series of models by Google, I get the following error when calling push_to_hub (which first saves the tensors in the safetensors format before uploading the files to the hub):

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-8-eac8a21155c9> in <cell line: 1>()
----> 1 model.push_to_hub(f"nielsr/gemma-2b-it")

/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py in inner_f(*args, **kwargs)
     99                     message += "\n\n" + custom_message
    100                 warnings.warn(message, FutureWarning)
--> 101             return f(*args, **kwargs)
    102 
    103         return inner_f

/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_validators.py in _inner_fn(*args, **kwargs)
    117             kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs)
    118 
--> 119         return fn(*args, **kwargs)
    120 
    121     return _inner_fn  # type: ignore

/usr/local/lib/python3.10/dist-packages/huggingface_hub/hub_mixin.py in push_to_hub(self, repo_id, config, commit_message, private, token, branch, create_pr, allow_patterns, ignore_patterns, delete_patterns, api_endpoint)
    517         with SoftTemporaryDirectory() as tmp:
    518             saved_path = Path(tmp) / repo_id
--> 519             self.save_pretrained(saved_path, config=config)
    520             return api.upload_folder(
    521                 repo_id=repo_id,

/usr/local/lib/python3.10/dist-packages/huggingface_hub/hub_mixin.py in save_pretrained(self, save_directory, config, repo_id, push_to_hub, **push_to_hub_kwargs)
    247 
    248         # save model weights/files (framework-specific)
--> 249         self._save_pretrained(save_directory)
    250 
    251         # save config (if provided)

/usr/local/lib/python3.10/dist-packages/huggingface_hub/hub_mixin.py in _save_pretrained(self, save_directory)
    590         """Save weights from a Pytorch model to a local directory."""
    591         model_to_save = self.module if hasattr(self, "module") else self  # type: ignore
--> 592         save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
    593 
    594     @classmethod

/usr/local/lib/python3.10/dist-packages/safetensors/torch.py in save_model(model, filename, metadata, force_contiguous)
    153     """
    154     state_dict = model.state_dict()
--> 155     to_removes = _remove_duplicate_names(state_dict)
    156 
    157     for kept_name, to_remove_group in to_removes.items():

/usr/local/lib/python3.10/dist-packages/safetensors/torch.py in _remove_duplicate_names(state_dict, preferred_names, discard_names)
     98     to_remove = defaultdict(list)
     99     for shared in shareds:
--> 100         complete_names = set([name for name in shared if _is_complete(state_dict[name])])
    101         if not complete_names:
    102             raise RuntimeError(

/usr/local/lib/python3.10/dist-packages/safetensors/torch.py in <listcomp>(.0)
     98     to_remove = defaultdict(list)
     99     for shared in shareds:
--> 100         complete_names = set([name for name in shared if _is_complete(state_dict[name])])
    101         if not complete_names:
    102             raise RuntimeError(

/usr/local/lib/python3.10/dist-packages/safetensors/torch.py in _is_complete(tensor)
     79 
     80 def _is_complete(tensor: torch.Tensor) -> bool:
---> 81     return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _SIZE[tensor.dtype] == storage_size(tensor)
     82 
     83 

KeyError: torch.complex64
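The last frame shows where it breaks: `_is_complete` multiplies the tensor's element count by a per-dtype byte size looked up in the `_SIZE` table, and `torch.complex64` has no entry there. The torch-free sketch below models that pattern with a hypothetical `SIZE` table keyed by dtype name (its contents are illustrative, not copied from safetensors) to show why any dtype missing from the table raises `KeyError`, and what adding an entry would look like (complex64 is two float32 values, so 8 bytes per element).

```python
# Hypothetical per-dtype byte sizes, keyed by name; modeled on, not copied from, safetensors' _SIZE.
SIZE = {"float32": 4, "float16": 2, "bfloat16": 2, "int64": 8}


def is_complete(nelement, dtype, storage_nbytes, size_table=SIZE):
    # A tensor covers its whole storage when element count * bytes per element == storage size.
    # The dict lookup raises KeyError for any dtype missing from the table.
    return nelement * size_table[dtype] == storage_nbytes


try:
    is_complete(10, "complex64", 80)
except KeyError as err:
    print("KeyError:", err)  # same failure mode as the traceback above

# Adding entries for complex dtypes (2 x float32 = 8 bytes, 2 x float64 = 16 bytes) avoids it.
fixed = {**SIZE, "complex64": 8, "complex128": 16}
print(is_complete(10, "complex64", 80, size_table=fixed))  # True
```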

Here's a notebook for reproduction.

Expected behavior

This model has some tensors of type torch.complex64; it would be great to be able to save those.
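Until complex dtypes are handled, one possible workaround (our own sketch, not something safetensors or huggingface_hub provides) is to store each complex tensor as a real tensor with a trailing dimension of size 2 via `torch.view_as_real`, and rebuild it with `torch.view_as_complex` after loading. The helper names `split_complex` and `merge_complex` are hypothetical. This bypasses the mixin's own saving path, so it only helps if you write the file yourself.

```python
import torch


def split_complex(state_dict):
    """Replace each complex tensor with a real tensor of shape (..., 2).

    Returns the new dict and the sorted list of converted keys, so the caller
    can record them and undo the conversion after loading.
    """
    out, complex_keys = {}, []
    for name, tensor in state_dict.items():
        if tensor.is_complex():
            # resolve_conj(): view_as_real rejects lazily conjugated tensors.
            # clone(): give the entry its own storage instead of a view.
            out[name] = torch.view_as_real(tensor.resolve_conj()).clone()
            complex_keys.append(name)
        else:
            out[name] = tensor
    return out, sorted(complex_keys)


def merge_complex(state_dict, complex_keys):
    """Inverse of split_complex: float32 (..., 2) -> complex64, float64 -> complex128."""
    out = dict(state_dict)
    for name in complex_keys:
        out[name] = torch.view_as_complex(out[name].contiguous())
    return out
```

The converted dict can then be passed to `safetensors.torch.save_file(tensors, filename, metadata=...)`; storing `complex_keys` in `metadata` as a comma-separated string is an assumed convention, not part of the format. The sketch does not handle tied (shared) weights, which `save_file` refuses to save.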

github-actions[bot] commented 5 months ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

github-actions[bot] commented 3 months ago

This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 5 days.

npuichigo commented 2 months ago

Same issue here.

NielsRogge commented 2 months ago

Friendly ping @Narsil.