microsoft / DeepSpeedExamples

Example models using DeepSpeed
Apache License 2.0

DeepSpeed-Chat step1 SFT evaluation error: size mismatch #280

Closed. M1n9X closed this issue 1 year ago.

M1n9X commented 1 year ago

Hi,

I tried to reproduce the whole process on an 8xV100 server with the following command:

python train.py --actor-model facebook/opt-13b --reward-model facebook/opt-350m --num-gpus 8

After successfully fine-tuning the model in step 1, I tried to evaluate it, but an error occurred:

RuntimeError: Error(s) in loading state_dict for OPTForCausalLM:
        size mismatch for model.decoder.embed_tokens.weight: copying a param with shape torch.Size([50272, 2048]) from checkpoint, the shape in current model is torch.Size([50265, 2048]).
        size mismatch for lm_head.weight: copying a param with shape torch.Size([50272, 2048]) from checkpoint, the shape in current model is torch.Size([50265, 2048]).
        You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
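The message above comes from a per-parameter shape comparison between the checkpoint and the freshly built model. The following is a simplified sketch of that kind of check. It uses plain tuples instead of tensors, it is an illustration rather than the actual `transformers` code, and `find_shape_mismatches` is a hypothetical helper:

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    checkpoint_shapes: Dict[str, Shape],
    model_shapes: Dict[str, Shape],
) -> List[Tuple[str, Shape, Shape]]:
    """Return (key, checkpoint_shape, model_shape) for every shared key whose shapes differ."""
    mismatches = []
    for key, ckpt_shape in checkpoint_shapes.items():
        # Only keys present in both dicts are compared; missing keys are a separate problem.
        if key in model_shapes and model_shapes[key] != ckpt_shape:
            mismatches.append((key, ckpt_shape, model_shapes[key]))
    return mismatches


# Shapes copied from the error message above.
checkpoint = {
    "model.decoder.embed_tokens.weight": (50272, 2048),
    "lm_head.weight": (50272, 2048),
}
model = {
    "model.decoder.embed_tokens.weight": (50265, 2048),
    "lm_head.weight": (50265, 2048),
}

for key, ckpt, cur in find_shape_mismatches(checkpoint, model):
    print(f"size mismatch for {key}: checkpoint {ckpt}, current model {cur}")
```

Only the vocabulary dimension differs (50272 vs 50265), while the hidden size (2048) matches. This points at the vocabulary size the model was built with rather than at the architecture itself.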

Then I tried different configurations, but the error persisted.

After adding ignore_mismatched_sizes=True to the model's from_pretrained call, the evaluation still failed. This time, the same error occurred for both of the failure cases above:

/usr/local/lib/python3.8/dist-packages/transformers/models/auto/auto_factory.py:471 in from_pretrained

     468             )
     469         elif type(config) in cls._model_mapping.keys():
     470             model_class = _get_model_class(config, cls._model_mapping)
   ❱ 471             return model_class.from_pretrained(
     472                 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs,
     473             )
     474         raise ValueError(

/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:2709 in from_pretrained

    2706                 mismatched_keys,
    2707                 offload_index,
    2708                 error_msgs,
  ❱ 2709             ) = cls._load_pretrained_model(
    2710                 model,
    2711                 state_dict,
    2712                 loaded_state_dict_keys,  # XXX: rename?

/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:3027 in _load_pretrained_model

    3024
    3025                 # Mistmatched keys contains tuples key/shape1/shape2 of weights in the c
    3026                 # matching the weights in the model.
  ❱ 3027                 mismatched_keys += _find_mismatched_keys(
    3028                     state_dict,
    3029                     model_state_dict,
    3030                     original_loaded_keys,

/usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:2950 in _find_mismatched_keys

    2947
    2948                     if (
    2949                         model_key in model_state_dict
  ❱ 2950                         and state_dict[checkpoint_key].shape != model_state_dict[model_k
    2951                     ):
    2952                         mismatched_keys.append(
    2953                             (checkpoint_key, state_dict[checkpoint_key].shape, model_sta
KeyError: 'decoder.layers.15.fc1.bias'

Package versions:

BTW, I noticed the --num_padding_at_beginning argument described for step 2. Is there anything similar in step 1, e.g. padding or something else, that could cause this size mismatch (50272 vs 50265)?

So how should I evaluate the SFT model? Thanks in advance.

aopolin-lv commented 1 year ago

I ran into this error too. It is caused by the change to config.vocab_size. Just reload the config before initializing model_finetuned: copy line 213 and paste it before line 219.
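The following is a minimal, self-contained simulation of the failure mode described in this comment. It assumes that the baseline and fine-tuned models are built from one shared config object, and that resizing the token embeddings writes the new size back into that config (as `transformers`' `resize_token_embeddings` does with `config.vocab_size`). `ToyConfig`, `ToyModel`, `load_config`, and `load_checkpoint` are hypothetical stand-ins, not DeepSpeed or `transformers` APIs:

```python
from dataclasses import dataclass

CHECKPOINT_VOCAB = 50272  # embedding rows stored in the step-1 checkpoint
TOKENIZER_LEN = 50265     # size the embeddings get resized to (len(tokenizer))


@dataclass
class ToyConfig:
    vocab_size: int


def load_config() -> ToyConfig:
    # Hypothetical stand-in for re-reading the config from disk: always a fresh object.
    return ToyConfig(vocab_size=CHECKPOINT_VOCAB)


class ToyModel:
    def __init__(self, config: ToyConfig):
        self.config = config  # keeps a reference to the caller's config, not a copy
        self.embed_rows = config.vocab_size

    def resize_token_embeddings(self, new_size: int) -> None:
        # Mirrors the side effect under discussion: the new size is written back to the config.
        self.embed_rows = new_size
        self.config.vocab_size = new_size


def load_checkpoint(config: ToyConfig) -> ToyModel:
    # Builds a model from `config`, then "loads" weights with CHECKPOINT_VOCAB rows.
    model = ToyModel(config)
    if model.embed_rows != CHECKPOINT_VOCAB:
        raise RuntimeError(
            f"size mismatch: checkpoint {CHECKPOINT_VOCAB}, current model {model.embed_rows}"
        )
    return model


config = load_config()
baseline = ToyModel(config)
baseline.resize_token_embeddings(TOKENIZER_LEN)  # silently changes config.vocab_size

try:
    load_checkpoint(config)  # fails: the shared config now says 50265
except RuntimeError as err:
    print("before fix:", err)

# Fix suggested in the comment: reload the config before building the fine-tuned model.
finetuned = load_checkpoint(load_config())
print("after fix:", finetuned.embed_rows)
```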

M1n9X commented 1 year ago


Thanks for the info, it works now. Based on your suggestion, I found that the root cause is in the get_model method. As a fix, I use a variable to hold the original value of config.vocab_size and restore it after the tokenizer-related operations.
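The following sketches that save-and-restore fix. It assumes a get_model-style helper that receives a shared config and resizes the embeddings to the tokenizer length. The real script is not shown in the thread, so `ToyConfig`, `ToyModel`, and this `get_model` signature are hypothetical stand-ins. `get_model_isolated` is a swapped-in alternative using `copy.deepcopy`, not what the poster did:

```python
import copy
from dataclasses import dataclass


@dataclass
class ToyConfig:
    vocab_size: int


class ToyModel:
    def __init__(self, config: ToyConfig):
        self.config = config  # shared reference to the caller's config
        self.embed_rows = config.vocab_size

    def resize_token_embeddings(self, new_size: int) -> None:
        # Side effect under discussion: the shared config is updated too.
        self.embed_rows = new_size
        self.config.vocab_size = new_size


def get_model(config: ToyConfig, tokenizer_len: int) -> ToyModel:
    original_vocab_size = config.vocab_size  # remember the checkpoint's value
    model = ToyModel(config)
    model.resize_token_embeddings(tokenizer_len)
    config.vocab_size = original_vocab_size  # restore it for the next model built from config
    return model


def get_model_isolated(config: ToyConfig, tokenizer_len: int) -> ToyModel:
    # Alternative: give each model its own copy so the caller's config never changes.
    model = ToyModel(copy.deepcopy(config))
    model.resize_token_embeddings(tokenizer_len)
    return model


shared = ToyConfig(vocab_size=50272)
baseline = get_model(shared, 50265)
print(shared.vocab_size, baseline.embed_rows)  # 50272 50265

isolated_cfg = ToyConfig(vocab_size=50272)
isolated = get_model_isolated(isolated_cfg, 50265)
print(isolated_cfg.vocab_size, isolated.config.vocab_size)  # 50272 50265
```

One trade-off in this toy: because the model keeps a reference to the same config, restoring the value also rewrites the baseline model's own `config.vocab_size`, so it no longer matches its resized embeddings. The deep-copy variant keeps each model's config consistent with its own embeddings.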