Vahe1994 / SpQR

Apache License 2.0

LLaMa 30B loading error #43

Closed. DavidePaglieri closed this issue 7 months ago

DavidePaglieri commented 7 months ago

Hi, I'm trying to test this on the LLaMa 30B model, but I get the following error:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│                                                                              │
│ /home/user/SpQR/main.py:577 in <module>                                  │
│                                                                              │
│   574 │   device = "cuda" if torch.cuda.is_available() else "cpu"            │
│   575 │                                                                      │
│   576 │   print("============  Loading model... ============")               │
│ ❱ 577 │   model = get_model(args.model_path, args.load, args.dtype).train(Fa │
│   578 │                                                                      │
│   579 │   print("\n============ Quantizing model... ============")           │
│   580 │   if args.wbits < 16 and args.load:                                  │
│ /home/user/SpQR/modelutils.py:45 in get_model                            │
│                                                                              │
│    42 │   │   │   model = load_quantized_model(model, load_quantized)        │
│    43 │   │   else:                                                          │
│    44 │   │   │   print("Loading pretrained model ...")                      │
│ ❱  45 │   │   │   model = AutoModelForCausalLM.from_pretrained(              │
│    46 │   │   │   │   pretrained_model_name_or_path=model_path,              │
│    47 │   │   │   │   trust_remote_code=True,                                │
│    48 │   │   │   │   torch_dtype=dtype,                                     │
│                                                                              │
│ /home/user/.local/lib/python3.9/site-packages/transformers/models/auto/a │
│ uto_factory.py:467 in from_pretrained                                        │
│                                                                              │
│   464 │   │   │   )                                                          │
│   465 │   │   elif type(config) in cls._model_mapping.keys():                │
│   466 │   │   │   model_class = _get_model_class(config, cls._model_mapping) │
│ ❱ 467 │   │   │   return model_class.from_pretrained(                        │
│   468 │   │   │   │   pretrained_model_name_or_path, *model_args, config=con │
│   469 │   │   │   )                                                          │
│   470 │   │   raise ValueError(                                              │
│                                                                              │
│ /home/user/.local/lib/python3.9/site-packages/transformers/modeling_util │
│ s.py:2777 in from_pretrained                                                 │
│                                                                              │
│   2774 │   │   │   │   mismatched_keys,                                      │
│   2775 │   │   │   │   offload_index,                                        │
│   2776 │   │   │   │   error_msgs,                                           │
│ ❱ 2777 │   │   │   ) = cls._load_pretrained_model(                           │
│   2778 │   │   │   │   model,                                                │
│   2779 │   │   │   │   state_dict,                                           │
│   2780 │   │   │   │   loaded_state_dict_keys,  # XXX: rename?               │
│                                                                              │
│ /home/user/.local/lib/python3.9/site-packages/transformers/modeling_util │
│ s.py:3104 in _load_pretrained_model                                          │
│                                                                              │
│   3101 │   │   │   │   # Skip the load for shards that only contain disk-off │
│   3102 │   │   │   │   if shard_file in disk_only_shard_files:               │
│   3103 │   │   │   │   │   continue                                          │
│ ❱ 3104 │   │   │   │   state_dict = load_state_dict(shard_file)              │
│   3105 │   │   │   │                                                         │
│   3106 │   │   │   │   # Mistmatched keys contains tuples key/shape1/shape2  │
│   3107 │   │   │   │   # matching the weights in the model.                  │
│                                                                              │
│ /home/user/.local/lib/python3.9/site-packages/transformers/modeling_util │
│ s.py:444 in load_state_dict                                                  │
│                                                                              │
│    441 │   │   │   raise NotImplementedError(                                │
│    442 │   │   │   │   f"Conversion from a {metadata['format']} safetensors  │
│    443 │   │   │   )                                                         │
│ ❱  444 │   │   return safe_load_file(checkpoint_file)                        │
│    445 │   try:                                                              │
│    446 │   │   return torch.load(checkpoint_file, map_location="cpu")        │
│    447 │   except Exception as e:                                            │
│                                                                              │
│ /home/user/.local/lib/python3.9/site-packages/safetensors/torch.py:101   │
│ in load_file                                                                 │
│                                                                              │
│    98 │   result = {}                                                        │
│    99 │   with safe_open(filename, framework="pt", device=device) as f:      │
│   100 │   │   for k in f.keys():                                             │
│ ❱ 101 │   │   │   result[k] = f.get_tensor(k)                                │
│   102 │   return result                                                      │
│   103                                                                        │
│   104                                                                        │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: shape '[6656, 6656]' is invalid for input of size 33697155

I am running the same command as in the README:

python3 main.py $MODEL_PATH $DATASET \
    --wbits 4 \
    --groupsize 16 \
    --perchannel \
    --qq_scale_bits 3 \
    --qq_zero_bits 3 \
    --qq_groupsize 16 \
    --outlier_threshold=0.2 \
    --permutation_order act_order \
    --percdamp 1e0 \
    --offload_activations \
    --nsamples 4

Any ideas how to fix this?

DavidePaglieri commented 7 months ago

Looks like it might have been a problem with the safetensors file. Downloading it again solved the issue.
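
For anyone hitting the same error: a `[6656, 6656]` tensor needs 6656 × 6656 = 44,302,336 elements, but the loader only found 33,697,155, which is consistent with a truncated or otherwise incomplete shard. Below is a minimal sketch of a layout check that can be run before loading, assuming the standard safetensors layout (an 8-byte little-endian header length, a JSON header, then the raw tensor data). The function name `check_safetensors_file` and the `DTYPE_SIZES` table are illustrative and not part of SpQR or the safetensors library; the sketch uses only the Python standard library.

```python
import json
import os
import struct

# Bytes per element for common safetensors dtype strings (not exhaustive).
DTYPE_SIZES = {
    "F64": 8, "F32": 4, "F16": 2, "BF16": 2,
    "I64": 8, "I32": 4, "I16": 2, "I8": 1, "U8": 1, "BOOL": 1,
}


def check_safetensors_file(path):
    """Return a list of layout problems found in the file (empty if none)."""
    file_size = os.path.getsize(path)
    with open(path, "rb") as f:
        prefix = f.read(8)
        if len(prefix) < 8:
            return ["file is shorter than the 8-byte header length prefix"]
        # The first 8 bytes hold the JSON header length (little-endian u64).
        (header_len,) = struct.unpack("<Q", prefix)
        if 8 + header_len > file_size:
            return ["declared header length exceeds the file size"]
        header_bytes = f.read(header_len)

    try:
        header = json.loads(header_bytes)
    except ValueError:
        return ["header is not valid JSON"]
    if not isinstance(header, dict):
        return ["header is not a JSON object"]

    problems = []
    data_size = file_size - 8 - header_len  # bytes actually present after the header
    expected_end = 0
    for name, info in header.items():
        if name == "__metadata__":  # optional free-form metadata, not a tensor
            continue
        begin, end = info["data_offsets"]  # offsets relative to the data section
        n_elements = 1
        for dim in info["shape"]:
            n_elements *= dim
        itemsize = DTYPE_SIZES.get(info["dtype"])
        # Skip the per-tensor size check for dtypes missing from the table.
        if itemsize is not None and n_elements * itemsize != end - begin:
            problems.append(f"{name}: shape and dtype do not match its byte range")
        expected_end = max(expected_end, end)

    if expected_end != data_size:
        problems.append(
            f"data section holds {data_size} bytes, header expects {expected_end}"
        )
    return problems
```

Calling `check_safetensors_file(shard_path)` on each downloaded shard (where `shard_path` is a placeholder for the local file path) returns an empty list when the file size agrees with its header. This only detects truncation or a mismatched layout; it does not detect flipped bytes inside the data, for which comparing the file's SHA-256 against a published checksum would be the more reliable check.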