facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

How can I save a checkpoint whose weights have been altered? #4664

Open robotsp opened 2 years ago

robotsp commented 2 years ago

❓ Questions and Help

gmryu commented 2 years ago

You may try

from fairseq import checkpoint_utils
checkpoint_utils.torch_persistent_save(state_dict, filename, async_write=False)
print(f"Finished saving checkpoint to {filename}")

related info: checkpoint_utils, trainer's save_checkpoint

Is this what you need?

robotsp commented 2 years ago

@gmryu Thanks. Yes, I think so. But I tried it and it raises an error when I reload the model:

args = state["args"] if arg_overrides is not None: for arg_name, arg_val in arg_overrides.items(): KeyError: 'args'

I think some entries of the state dict could be missing when I save. My script:

model = TransformerModel.from_pretrained(
    '/path_to_model',
    checkpoint_file='model_name',
    data_name_or_path='/path_to_data'
)

state_dict = model.state_dict()  # weights altered

from fairseq import checkpoint_utils
checkpoint_utils.torch_persistent_save(state_dict, "new_model", async_write=False)

# reload the new model -> raises the KeyError above

gmryu commented 2 years ago

@robotsp I guess that will happen, since a model state dict is not a checkpoint for fairseq. I found you can use torch.save instead. The flow is: load the old model again and call load_state_dict(your edited state_dict).

import torch
from fairseq.models.transformer.transformer_legacy import TransformerModel

ddp="{directory of dict.xxx.txt}"
new_path="{the new model file path}" # note: this is not a checkpoint, it stores only model.state_dict()

model = TransformerModel.from_pretrained( '/path_to_old_model', checkpoint_file='model_name', data_name_or_path=ddp)
torch.save(model.state_dict(), new_path)

mdl2 = TransformerModel.from_pretrained( '/path_to_old_model', checkpoint_file='model_name', data_name_or_path=ddp)
# Instead of loading again, you could initialize a TransformerModel instance with the same structure.
# Loading takes a long time, but writing code to build an identical model is also annoying, so pick whichever suits you.

mdl2.load_state_dict(torch.load(new_path))
# It should report "All keys matched successfully"

Well, one more idea is to write a new Transformer.py that inherits from the original and has a new load method. It may reduce your loading time as well. Good luck.

robotsp commented 2 years ago

Looks great! How can I save mdl2 as a checkpoint after I load the state_dict? Is there any method from fairseq?

gmryu commented 2 years ago

@robotsp I guess at this point you do not load it as a model, but load the checkpoint itself.

ckpt_state_dict = torch.load("{your old model checkpoint.pt}")
print(ckpt_state_dict)  # it should have keys like "model", "cfg"
# by editing the tensors inside "model", you can edit your model.
# the downside is you cannot treat it as a model instance

# after adjustment
torch.save(ckpt_state_dict, "{new checkpoint.pt}")

robotsp commented 2 years ago

I guess at this point you do not load it as a model, but load the checkpoint itself.

@gmryu Thanks for your reply. I am going to finetune the new model, but I cannot load and train it with fairseq-train. Could you please provide the code?

Including these steps:

  1. load the old model and get the state_dict
  2. alter the state_dict and initialize the new model
  3. save the new model
  4. reload the new model and finetune it with fairseq-train
gmryu commented 2 years ago

@robotsp I gave the code before. Please at least tell me what happened when you tried it. I will just go through it again in case you did not see it or missed something.

First, you need to understand what a fairseq checkpoint is:

>>> import torch
>>> sd=torch.load("my_checkpoint.pt")
>>> print(sd.keys())
dict_keys(['cfg', 'args', 'model', 'optimizer_history', 'extra_state', 'last_optimizer_state'])

So you know the checkpoint has the above elements. Looking into "model", you find:

>>> print(sd["model"].keys())
dict_keys(['encoder.version', 'encoder.embed_tokens.weight', 'encoder.embed_positions._float_tensor', 'encoder.layers.0.self_attn.k_proj.weight', 'encoder.layers.0.self_attn.k_proj.bias', 'encoder.layers.0.self_attn.v_proj.weight', ..... ])

Obviously, you can access those weight tensors by key, say sd["model"]["encoder.embed_tokens.weight"], and edit them directly inside the dict. You do not need to create a TransformerModel or any other model instance.
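For example, an in-place edit could look like this (scaling by 0.5 is just an arbitrary placeholder edit):

>>> sd["model"]["encoder.embed_tokens.weight"] *= 0.5  # arbitrary example edit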

Once you are done editing, you can save it and load it later:

>>> torch.save(sd, "tmp2.pt")
>>> exit()
$ python
>>> import torch
>>> sd2=torch.load("tmp2.pt")
>>> print(sd2.keys())
dict_keys(['cfg', 'args', 'model', 'optimizer_history', 'extra_state', 'last_optimizer_state'])

This tmp2.pt can be used by fairseq-train or fairseq-hydra-train. I hope this makes what I wrote clear; please share your error log if you run into anything.

robotsp commented 2 years ago

Thanks @gmryu, it works. By the way, do you know any methods for pruning (L1 pruning, etc.) fairseq models?

gmryu commented 2 years ago

@robotsp I have not pruned or distilled any models. But if pytorch provides methods to prune a module, I believe you can apply them to the tensors in ["model"]["your weight"]. (It might need two parts: 1. load the model and prune it, then save only the pruned weights with torch or numpy; 2. load the checkpoint and update its weights from the saved weights.)
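A rough, untested sketch of that two-part flow (the file paths are placeholders; torch.nn.utils.prune does the actual pruning, and `.models[0]` is assumed to be how the hub interface returned by from_pretrained exposes the underlying model):

import torch
from torch.nn.utils import prune
from fairseq.models.transformer import TransformerModel

# Part 1: load the model, prune it, and save only the pruned weights.
hub = TransformerModel.from_pretrained('/path_to_model', checkpoint_file='checkpoint.pt',
                                       data_name_or_path='/path_to_data')
model = hub.models[0]                       # the underlying TransformerModel
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.5)
        prune.remove(module, 'weight')      # make the pruning permanent
torch.save(model.state_dict(), 'pruned_weights.pt')

# Part 2: load the original checkpoint and overwrite its "model" weights.
ckpt = torch.load('checkpoint.pt')
ckpt['model'].update(torch.load('pruned_weights.pt'))
torch.save(ckpt, 'pruned_checkpoint.pt')    # still a full fairseq checkpoint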

If you succeed in pruning models, I would be grateful if you could share it here.

robotsp commented 2 years ago

@gmryu I tried using Microsoft NNI to prune a fairseq model, and here is the code:

import torch
from torch import nn
from torchvision import models
from torch import optim
import torchvision.transforms as transforms
from torchsummary import summary
import numpy as np
import random
import nni
from nni.algorithms.compression.pytorch.pruning import LevelPruner
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import L1NormPruner
from nni.algorithms.compression.pytorch.pruning import L2FilterPruner
from nni.compression.pytorch import ModelSpeedup
from fairseq.models.transformer import TransformerModel

model = TransformerModel.from_pretrained(
    'path_to_checkpoint',
    checkpoint_file='checkpoint.pt',
    data_name_or_path='path_to_data'
)

config_list = [{'sparsity': 0.5, 'op_types': ['Linear']}]
pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
pruner._unwrap_model()
ModelSpeedup(model, dummy_input=??, masks_file=masks).speedup_model()

I checked the Microsoft NNI documentation and it only provides a Conv-based demo. Do you know what the dummy_input (input size) of a Transformer-based model should be?

ref: https://nni.readthedocs.io/en/latest/tutorials/pruning_quick_start_mnist.html

gmryu commented 2 years ago

@robotsp
About that dummy_input: it is an example of acceptable input for the model. For instance, for blenderbot in huggingface it is written as:

@property
    def dummy_inputs(self):
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
            "decoder_input_ids": input_ids,
        }
        return dummy_inputs

And this by itself is accepted by blenderbot's forward(**dummy_inputs).

So for a fairseq transformer, you need data that suits TransformerModel's forward. From that forward, you probably need:

src_tokens,          # 2D tensor (batch_size * sentence_length); each element is no bigger than your vocab size
src_lengths,         # 1D tensor; each entry is the given sentence length
prev_output_tokens,  # 2D tensor (batch_size * sentence_length); each element is no bigger than your vocab size

To prune, it probably requires a simple run through (like an evaluation), so you need to provide data here.

I hope this helps, but I am also new to this field, so let me know if it works. I may not be able to reply very soon.

--

I do not know whether there is a recommended sentence length or batch size. If you are serious about this, I recommend you try 2-3 different inputs. Say, one actually taken from your data, and one made of tensors of random integers (see the sketch below).
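A minimal sketch of such a random-integer dummy batch (the sizes and the vocab size below are made up; use values that match your own dictionary and data):

import torch

vocab_size = 1000            # made-up value; use your dictionary size
batch_size, seq_len = 2, 10  # made-up sizes

dummy_input = {
    # low=4 keeps clear of the usual fairseq special symbols (bos/pad/eos/unk)
    "src_tokens": torch.randint(low=4, high=vocab_size, size=(batch_size, seq_len)),
    "src_lengths": torch.full((batch_size,), seq_len, dtype=torch.long),
    "prev_output_tokens": torch.randint(low=4, high=vocab_size, size=(batch_size, seq_len)),
}
# something like model(**dummy_input) should then run if `model` is a TransformerModel instance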

robotsp commented 2 years ago

@gmryu I tried to assign tensors of random integers as input and it raises an error at ModelSpeedup(model, dummy_input=torch.rand([64,50,50,64,50]), masks_file=masks).speedup_model()

error msg:

NotImplementedError                       Traceback (most recent call last)
<ipython-input-53-4204d8c2effc> in <module>
----> 1 ModelSpeedup(model, dummy_input=torch.rand([64,50,50,64,50]), masks_file=masks).speedup_model()

/usr/local/lib/python3.6/site-packages/nni/compression/pytorch/speedup/compressor.py in __init__(self, model, dummy_input, masks_file, map_location, batch_dim, confidence)
     55         self.batch_dim = batch_dim
     56         self.dummy_input, self.device = self._random_model_input(dummy_input, confidence, batch_dim)
---> 57         self.torch_graph = build_module_graph(model, self.dummy_input)
     58         # dict object to save the auto inferences objects of the submodules
     59         self.auto_inferences = {}

/usr/local/lib/python3.6/site-packages/nni/common/graph_utils.py in build_module_graph(model, dummy_input)
     22 
     23 def build_module_graph(model, dummy_input):
---> 24     return TorchModuleGraph(model, dummy_input)
     25 
     26 

/usr/local/lib/python3.6/site-packages/nni/common/graph_utils.py in __init__(self, model, dummy_input, traced_model)
    250 
    251     def __init__(self, model=None, dummy_input=None, traced_model=None):
--> 252         super().__init__(model, dummy_input, traced_model)
    253         self.global_count = 0
    254         self.reused_module = set()

/usr/local/lib/python3.6/site-packages/nni/common/graph_utils.py in __init__(self, model, dummy_input, traced_model)
     64         elif model is not None and dummy_input is not None:
     65             self.bound_model = model
---> 66             self._trace(model, dummy_input)
     67         else:
     68             raise Exception(

/usr/local/lib/python3.6/site-packages/nni/common/graph_utils.py in _trace(self, model, dummy_input)
     76             # only pytorch with version greater than 1.6.0 has the strict option
     77             kw_args['strict'] = False
---> 78         self.trace = torch.jit.trace(model, dummy_input, **kw_args)
     79         torch._C._jit_pass_inline(self.trace.graph)
     80         model.train(training)

/usr/local/lib/python3.6/site-packages/torch/jit/_trace.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    740             strict,
    741             _force_outplace,
--> 742             _module_class,
    743         )
    744 

/usr/local/lib/python3.6/site-packages/torch/jit/_trace.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    938                 var_lookup_fn,
    939                 strict,
--> 940                 _force_outplace,
    941             )
    942             check_trace_method = module._c._get_method(method_name)

/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    723                 input = result
    724         if torch._C._get_tracing_state():
--> 725             result = self._slow_forward(*input, **kwargs)
    726         else:
    727             result = self.forward(*input, **kwargs)

/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
    707                 recording_scopes = False
    708         try:
--> 709             result = self.forward(*input, **kwargs)
    710         finally:
    711             if recording_scopes:

/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
    173         registered hooks while the latter silently ignores them.
    174     """
--> 175     raise NotImplementedError
    176 
    177 

NotImplementedError: 
robotsp commented 2 years ago

@gmryu Does it mean nni does not support fairseq framework yet?

gmryu commented 2 years ago

@robotsp Sorry, I cannot reply during my working hours. (Also, have you actually read everything I wrote? Is my phrasing too hard to understand?)

Nonetheless, you may need to know:

  1. dummy_input is literally a "dummy" batch of data for the model. It is a python dictionary, not a single tensor.
  2. What class is the model you used here? Is it fairseq's TransformerModel? If you are using fairseq's TransformerModel as the model, it requires a batch of the form:
    # if the model is a fairseq Transformer,
    dummy_input = {
        "src_tokens":         ...,  # 2D tensor (batch_size * sentence_length); each element is no bigger than your vocab size
        "src_lengths":        ...,  # 1D tensor; each entry is the given sentence length
        "prev_output_tokens": ...,  # 2D tensor (batch_size * sentence_length); each element is no bigger than your vocab size
    }

If it is not a TransformerModel, you need to find out what its forward needs, or what kind of batch that model can accept.

robotsp commented 2 years ago

@gmryu, I tried using PyTorch's torch.nn.utils.prune to prune the fairseq model, and it successfully sets the less important weights to zero. Then I save the state_dict using torch.save. The question is: when I load it and retrain the model, the zeroed weights are updated again.

Here are the steps:

  1. Use `fairseq.from_pretrained` to load the model, and prune it using nn.prune. // get a new state_dict with pruned weights.
  2. Use torch.load() to get a new model and assign the new state_dict to it.
  3. Use torch.save() to save the new model with pruned weights.
  4. Use fairseq-train to retrain the model. (But the pruned weights are updated again.)

How can I freeze these unstructured pruned weights and avoid their being updated during the retraining process?

gmryu commented 2 years ago

@robotsp If you want to stop a tensor from updating, I guess the simple way is to freeze it directly in the code. For example, you can copy the TransformerModel implementation and freeze those weights in its build_model and load_state_dict methods. Then use --user-dir to import your custom model and --arch to tell fairseq to use your custom model class (see the sketch below).

(Or you may try to freeze some tensors, i.e. model weights, save them, and reload again. That will tell you whether saving tensors keeps their frozen state or not.)
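A rough, untested sketch of that --user-dir idea (the directory name my_plugins/, the model name frozen_transformer, and the choice of which parameters to freeze are all made up for illustration; exact import paths may differ between fairseq versions):

# my_plugins/__init__.py
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerModel, base_architecture


@register_model("frozen_transformer")
class FrozenTransformerModel(TransformerModel):
    @classmethod
    def build_model(cls, args, task):
        model = super().build_model(args, task)
        # Example freezing policy: keep the encoder embeddings fixed during training.
        # Replace this with whatever parameters you actually want to freeze.
        for param in model.encoder.embed_tokens.parameters():
            param.requires_grad = False
        return model


@register_model_architecture("frozen_transformer", "frozen_transformer")
def frozen_transformer_arch(args):
    base_architecture(args)

fairseq-train would then be launched with --user-dir my_plugins --arch frozen_transformer.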

robotsp commented 2 years ago

@gmryu I have checked some weight-freezing methods, but all of them freeze a whole module rather than particular tensors (in my case, I only want to freeze the zero weights from pruning). requires_grad affects the whole module, right?

gmryu commented 2 years ago

@robotsp check out this: https://discuss.pytorch.org/t/how-do-i-freeze-the-specific-weights-in-a-layer/104722/5

robotsp commented 2 years ago

(Or you may try to freeze some tensors, i.e. model weights, save them, and reload again. That will tell you whether saving tensors keeps their frozen state or not.)

It is similar to the pruning mechanism in torch.nn, but for a fairseq model it cannot be applied directly; instead I use these workaround steps:

  1. Use `fairseq.from_pretrained` to load the model, and prune it using nn.prune. // get a new state_dict with pruned weights.
  2. Use torch.load() to get a new model and assign the new state_dict to it.
  3. Use torch.save() to save the new model with pruned weights.
  4. Use fairseq-train to retrain the model. (But the pruned weights are updated again.)

Is there any method to freeze those zero weights after reloading?

gmryu commented 2 years ago

If you want to freeze only a slice of a single weight tensor, the short answer is that it is difficult and you would need to modify fairseq-train. The previous link talks about it; a sketch of its gradient-masking idea is below.
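A minimal, untested sketch of that gradient-masking idea, assuming the pruned weights are already exact zeros and that this is applied to the model built inside fairseq-train (e.g. right after task.build_model):

import torch

def freeze_pruned_entries(model):
    # Zero the gradient wherever a weight is already zero, so the optimizer
    # cannot move pruned entries away from zero. Note that this also catches
    # any weight that just happens to be exactly zero.
    for param in model.parameters():
        mask = (param.data != 0).to(param.dtype)
        param.register_hook(lambda grad, mask=mask: grad * mask)
    return model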

robotsp commented 2 years ago

@gmryu, I found that setting params.requires_grad = False does not work to freeze the weights of a fairseq model:

  1. I used torch.load("checkpoint.pt") and set params.requires_grad = False.
  2. I saved it as a weights-frozen model using torch.save().
  3. Then I used fairseq-train, but the weights are still updated.

Are there any problems in my implementation?

gmryu commented 2 years ago

@robotsp Sorry, I am not confident about this since I have not tried it yet. I guess I would need to try it myself later. However, I only have time on weekends, so it may take a while.

If you want to solve it now, I would suggest adjusting the TransformerModel implementation to freeze parts of the model during fairseq-train. I believe the build_model and load_state_dict methods are where you can freeze those parameters. (My guess as to why your attempt failed: fairseq-train rebuilds the model and load_state_dict copies only tensor values, so a requires_grad flag stored in a checkpoint file is simply ignored.)

robotsp commented 2 years ago

@gmryu Thanks. I will have a try then.

robotsp commented 2 years ago

@gmryu I have adjusted the code in fairseq_cli/train.py. I defined a pruning function:

def prune_model_global_unstructured(model, layer_type, proportion):
    from torch.nn.utils import prune
    import torch.nn as nn

    module_tups = []
    for module in model.modules():
        if isinstance(module, layer_type):
            module_tups.append((module, 'weight'))

    prune.global_unstructured(
        parameters=module_tups, pruning_method=prune.L1Unstructured,
        amount=proportion
    )
    for module, _ in module_tups:
        prune.remove(module, 'weight')
    return model

and call it after building the model in the main function:

# Build model and criterion
model = task.build_model(args)
model = prune_model_global_unstructured(model, nn.Linear, 0.5)  # nn.Linear is the module type you want to prune, and 0.5 is the proportion to prune.

and then use the fairseq-train command to start finetuning. Does that make sense?

robotsp commented 2 years ago

It starts working.

gmryu commented 2 years ago

@robotsp Congratulations! It is very nice of you to report back that this works. I hope it solves your issue.

robotsp commented 2 years ago

@gmryu you are so welcome.