huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Rewriting state_dict in self.model.save_pretrained() causes the saved '_metadata' to be missing. #14268

Closed changwangss closed 2 years ago

changwangss commented 2 years ago

## Environment info

## Who can help

function: self.model.save_pretrained() in trainer.py @sgugger

root cause: the state_dict rewrite in modeling_utils.py, added by @stas00 in PR #8737 to ignore keys on save

## Information

I am using Helsinki-NLP/opus-mt-en-ro for a translation task and quantizing it with Intel Neural Compressor (version 1.7).

I would load it from a pre-trained model, fine-tune it, quantize it, then save its state_dict. The issue happens when saving and reloading this quantized version.

When DynamicQuantizedLinear generates keys, nn.quantized.Linear uses this format: model.encoder.layers.0.self_attn.k_proj._packed_params._packed_params, which corresponds to version=3. However, saving with trainer.save_model() produces version=1 because _metadata is missing, and this makes reloading the quantized model fail. For more information about versions, you can see here in the pytorch repo.

    # Version 1
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #
    # Version 2
    #   self
    #   |--- weight : Tensor
    #   |--- bias : Tensor
    #   |--- dtype : torch.dtype
    #
    # Version 3
    #   self
    #   |--- _packed_params : (Tensor, Tensor) representing (weight, bias)
    #                         of LinearPackedParams
    #   |--- dtype : torch.dtype
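To illustrate why the missing metadata matters, here is a simplified sketch (not torch's actual code, key names shortened and hypothetical) of how a quantized module's `_load_from_state_dict` dispatches on the saved version: without `_metadata`, the loader falls back to the version-1 layout and pops a `weight` key that a version-3 state dict does not contain.

```python
def load_quantized_linear(state_dict, prefix, local_metadata):
    # Simplified version dispatch, modeled on the logic in
    # torch/nn/quantized/modules/linear.py.
    version = local_metadata.get("version", None)
    if version is None or version < 3:
        # Version 1/2 layout: expects plain 'weight' / 'bias' keys.
        return state_dict.pop(prefix + "weight")
    # Version 3 layout: weight and bias live under _packed_params.
    return state_dict.pop(prefix + "_packed_params._packed_params")

v3_state_dict = {"k_proj._packed_params._packed_params": ("w", "b")}

# With metadata intact, the version-3 path is taken and loading works.
print(load_quantized_linear(dict(v3_state_dict), "k_proj.", {"version": 3}))

# With _metadata lost, the loader assumes version 1 and pops a missing key.
try:
    load_quantized_linear(dict(v3_state_dict), "k_proj.", {})
except KeyError as exc:
    print("KeyError:", exc)
```

This is exactly the failure mode in the traceback below: the version-1 code path executes `state_dict.pop(prefix + 'weight')` against a version-3 state dict.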

We found that the root cause is the rewriting of state_dict in order to ignore keys, which drops the _metadata attribute that determines which version is chosen.

code link: https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_utils.py#L1052
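The loss of `_metadata` is easy to demonstrate without a model: PyTorch attaches it to the returned OrderedDict as an attribute rather than as a dict entry, so a dict comprehension produces a fresh dict without it. A minimal sketch (keys and version values are made up for illustration):

```python
from collections import OrderedDict

# PyTorch stores versioning info on the state dict as an attribute,
# not as an entry (values here are hypothetical).
state_dict = OrderedDict([("weight", 1.0), ("ignored_key", None)])
state_dict._metadata = {"": {"version": 3}}

# The current save path rebuilds the dict to drop ignored keys,
# which silently discards the _metadata attribute.
filtered = {k: v for k, v in state_dict.items() if k != "ignored_key"}
print(hasattr(filtered, "_metadata"))  # False

# Deleting keys in place, as proposed in this issue, preserves it.
del state_dict["ignored_key"]
print(hasattr(state_dict, "_metadata"))  # True
```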

## To reproduce

Steps to reproduce the behavior:

  1. Load the pre-trained model Helsinki-NLP/opus-mt-en-ro, fine-tune it, and quantize it with dynamic quantization.
  2. Save the quantized model and load it again; you will get an error.

    error:
    File "/home2/changwa1/anaconda3/envs/inc_example/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1388, in load
    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
    File "/home2/changwa1/anaconda3/envs/inc_example/lib/python3.6/site-packages/torch/nn/quantized/dynamic/modules/linear.py", line 72, in _load_from_state_dict
    missing_keys, unexpected_keys, error_msgs)
    File "/home2/changwa1/anaconda3/envs/inc_example/lib/python3.6/site-packages/torch/nn/quantized/modules/linear.py", line 220, in _load_from_state_dict
    weight = state_dict.pop(prefix + 'weight')
    KeyError: 'model.encoder.layers.0.self_attn.k_proj.weight'
3. Modify the code as follows, removing the unexpected keys from state_dict in place instead of rebuilding it; the reload will then succeed.

### Modify

The [rewrite state_dict code](https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_utils.py#L1052) in modeling_utils.py, line 1052.

Original:

    if self._keys_to_ignore_on_save is not None:
        state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}

Change:

    if self._keys_to_ignore_on_save is not None:
        for item in self._keys_to_ignore_on_save:
            del state_dict[item]
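One caveat with the in-place deletion (a sketch, not the final patch): `del state_dict[item]` raises KeyError if an ignored key is absent from the state dict, which can happen when quantization renames parameters. A defensive variant pops with a default:

```python
# Hypothetical keys for illustration; keys_to_ignore plays the role of
# self._keys_to_ignore_on_save.
state_dict = {"model.weight": 1.0, "model.bias": 0.0}
keys_to_ignore = ["model.bias", "lm_head.weight"]  # second key is absent

for item in keys_to_ignore:
    state_dict.pop(item, None)  # no KeyError for missing keys

print(sorted(state_dict))  # ['model.weight']
```

Since `pop` mutates the dict in place, this keeps the `_metadata` attribute intact just like the `del` version.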


## Expected behavior


You can modify it as I mentioned; if you have a more effective solution, that would be even better.
sgugger commented 2 years ago

I think your solution is very good to avoid deleting that _metadata attribute of the state dict. Would you like to make a PR out of it, since you found the fix?

changwangss commented 2 years ago

The PR has been submitted; please review. @sgugger