google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models
https://ai.google.dev/gemma
Apache License 2.0

`torch.load` without `weights_only` parameter is unsafe #1

Closed. kit1980 closed this issue 7 months ago

kit1980 commented 7 months ago

This was found via https://github.com/pytorch-labs/torchfix/

torch.load without weights_only parameter is unsafe. Explicitly set weights_only to False only if you trust the data you load and full pickle functionality is needed, otherwise set weights_only=True.

gemma/model.py:562:13

--- /home/sdym/repos/google/gemma_pytorch/gemma/model.py
+++ /home/sdym/repos/google/gemma_pytorch/gemma/model.py
@@ -557,9 +557,9 @@
         # If a string was provided as input, return a string as output.
         return results[0] if is_str_prompt else results

     def load_weights(self, model_path: str):
         self.load_state_dict(
-            torch.load(model_path, mmap=True)['model_state_dict'],
+            torch.load(model_path, mmap=True, weights_only=True)['model_state_dict'],
             strict=False,
         )

gemma/model_xla.py:517:22

--- /home/sdym/repos/google/gemma_pytorch/gemma/model_xla.py
+++ /home/sdym/repos/google/gemma_pytorch/gemma/model_xla.py
@@ -512,11 +512,11 @@
             top_ks=top_ks,
         )
         return next_tokens

     def load_weights(self, model_path: str):
-        checkpoint = torch.load(model_path)
+        checkpoint = torch.load(model_path, weights_only=True)
         model_state_dict = checkpoint['model_state_dict']

         num_attn_heads = self.config.num_attention_heads
         num_kv_heads = self.config.num_key_value_heads
         head_dim = self.config.head_dim
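
For reference, a minimal self-contained sketch of the pattern the fix applies (the toy nn.Linear model and the /tmp path are illustrative, not part of the repo): the checkpoint is saved with the state dict nested under 'model_state_dict', then reloaded with weights_only=True so torch.load only unpickles tensors and other allowlisted types instead of arbitrary Python objects.

import torch
import torch.nn as nn

model = nn.Linear(4, 4)

# Save in the same {'model_state_dict': ...} layout the Gemma checkpoints use.
torch.save({'model_state_dict': model.state_dict()}, '/tmp/ckpt.pt')

# Safe load: weights_only=True restricts unpickling to tensors and primitive
# containers; mmap=True maps the file instead of reading it all into memory.
checkpoint = torch.load('/tmp/ckpt.pt', mmap=True, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'], strict=False)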
pengchongjin commented 7 months ago

Thanks for reporting it. We are working on it.

michaelmoynihan commented 7 months ago

Hi Sergii, thanks for the suggestion! We have updated accordingly. I tested it out locally, it works, and I have created and merged a PR.