robotsp opened this issue 2 years ago
You may try
from fairseq import checkpoint_utils
checkpoint_utils.torch_persistent_save(state_dict, filename, async_write=False)
print(f"Finished saving checkpoint to {filename}")
Related info: checkpoint_utils, the trainer's save_checkpoint.
Is this what you need?
@gmryu Thanks. Yes, I think so. But when I tried it, it raised an error when I reloaded the model:
args = state["args"]
if arg_overrides is not None:
    for arg_name, arg_val in arg_overrides.items():
KeyError: 'args'
I think some state dict entries could be missing when I save. My script:
model = TransformerModel.from_pretrained( '/path_to_model', checkpoint_file='model_name', data_name_or_path='/path_to_data' )
state_dict = model.state_dict()
## weights altered
from fairseq import checkpoint_utils
checkpoint_utils.torch_persistent_save(state_dict, "new_model", async_write=False)
## reload the new model and KeyError
@robotsp I guess that will happen, since a model state dict is not a checkpoint for fairseq. I found you can use torch.save instead. The flow is: load the old model again and call load_state_dict(your_edited_state_dict).
import torch
from fairseq.models.transformer.transformer_legacy import TransformerModel
ddp="{directory of dict.xxx.txt}"
new_path="{the new model file path}" # note: this is not a checkpoint, it stores only model.state_dict()
model = TransformerModel.from_pretrained( '/path_to_old_model', checkpoint_file='model_name', data_name_or_path=ddp)
torch.save(model.state_dict(), new_path)
mdl2 = TransformerModel.from_pretrained( '/path_to_old_model', checkpoint_file='model_name', data_name_or_path=ddp)
# Alternatively, you can initialize a TransformerModel instance with the same structure.
# Loading takes a long time, but telling Python to build a model with the same structure is annoying too, so pick whichever you prefer.
mdl2.load_state_dict( torch.load(new_path) )
# It should report "All keys matched successfully"
Well, one more idea is to write a new Transformer.py, which inherits from the original and has a new load method. It may reduce your loading time as well. Good luck.
Looks great! How can I save mdl2 as a checkpoint after I load the state_dict? Is there any method from fairseq?
@robotsp I guess at this point, you do not load as a model but load the checkpoint itself.
ckpt_state_dict=torch.load("{your old model checkpoint.pt}")
print(ckpt_state_dict) # it should have something like "model","cfg"
# by editing the tensors inside "model", you can edit your model.
# the downside is that you cannot treat it as a model instance
# after adjustment
torch.save(ckpt_state_dict, "{new checkpoint.pt}")
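For instance, a minimal sketch of that flow (the paths and the halving edit are just placeholders):

import torch

# load the full checkpoint dict, not just the model weights
ckpt = torch.load("checkpoints/old_checkpoint.pt")  # hypothetical path

# edit tensors inside the "model" entry; halving the embedding
# weights here is only an illustrative edit
ckpt["model"]["encoder.embed_tokens.weight"] *= 0.5

# save the whole dict back, so "cfg", "args", etc. are preserved
torch.save(ckpt, "checkpoints/new_checkpoint.pt")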
@gmryu Thanks for your reply. I am going to finetune the new model, but I cannot load and train it with fairseq-train. Could you please provide the code, including the fairseq-train steps?
@robotsp I gave the code before. Please at least tell me what happened when you tried it. I will just go through it again in case you did not see it or missed something.
First, you need to understand what a fairseq checkpoint is:
>>> import torch
>>> sd=torch.load("my_checkpoint.pt")
>>> print(sd.keys())
dict_keys(['cfg', 'args', 'model', 'optimizer_history', 'extra_state', 'last_optimizer_state'])
So you know the checkpoint has the above elements. Looking into "model", you find:
>>> print(sd["model"].keys())
dict_keys(['encoder.version', 'encoder.embed_tokens.weight', 'encoder.embed_positions._float_tensor', 'encoder.layers.0.self_attn.k_proj.weight', 'encoder.layers.0.self_attn.k_proj.bias', 'encoder.layers.0.self_attn.v_proj.weight', ..... ])
Obviously, you can access those weight tensors by key, say sd["model"]["encoder.embed_tokens.weight"].
Edit them inside the dict; you do not need to create a TransformerModel or any other model instance.
Once you are done editing, you can save the checkpoint and load it later:
>>> torch.save(sd, "tmp2.pt")
>>> exit()
python
>>> import torch
>>> sd2=torch.load("tmp2.pt")
>>> print(sd2.keys())
dict_keys(['cfg', 'args', 'model', 'optimizer_history', 'extra_state', 'last_optimizer_state'])
This tmp2.pt can be used by fairseq-train or fairseq-hydra-train.
I hope this makes it clear; please post your error log if you run into anything.
Thanks @gmryu, it works. Btw, do you know any methods for pruning fairseq models (L1 pruning, etc.)?
@robotsp I have not pruned or distilled any models. But if pytorch provides methods to prune a module, I believe you can apply that method to ["model"]["your weight"]. (It would probably take two parts: 1. load the model, prune it, and save only the pruned weights with torch or numpy; 2. load the checkpoint and update its weights from your saved weights.)
If you succeed in pruning models, I would be grateful if you could share it here.
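If it helps, here is a rough, untested sketch of that two-part flow. All paths are placeholders, and the torch.nn.utils.prune usage is my assumption, not something I have run:

import torch
from torch.nn.utils import prune
from fairseq.models.transformer import TransformerModel

# part 1: load the model, prune it, and save only the pruned weights
hub = TransformerModel.from_pretrained(
    "/path_to_model", checkpoint_file="checkpoint_best.pt",
    data_name_or_path="/path_to_data")  # hypothetical paths
net = hub.models[0]  # the underlying TransformerModel, not the hub wrapper
for module in net.modules():
    if isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.5)
        prune.remove(module, "weight")  # bake the zeros into "weight"
torch.save(net.state_dict(), "pruned_weights.pt")

# part 2: load the checkpoint itself and overwrite its weights
ckpt = torch.load("/path_to_model/checkpoint_best.pt")
ckpt["model"].update(torch.load("pruned_weights.pt"))
torch.save(ckpt, "/path_to_model/checkpoint_pruned.pt")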
@gmryu I had a try using Microsoft NNI to prune a fairseq model; here is the code:
import torch
from torch import nn
from torchvision import models
from torch import optim
import torchvision.transforms as transforms
from torchsummary import summary
import numpy as np
import random
import nni
from nni.algorithms.compression.pytorch.pruning import LevelPruner
from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import L1NormPruner
from nni.algorithms.compression.pytorch.pruning import L2FilterPruner
from nni.compression.pytorch import ModelSpeedup
from fairseq.models.transformer import TransformerModel

model = TransformerModel.from_pretrained(
    'path_to_checkpoint', checkpoint_file='checkpoint.pt',
    data_name_or_path='path_to_data')

config_list = [{'sparsity': 0.5, 'op_types': ['Linear']}]
pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
pruner._unwrap_model()
ModelSpeedup(model, dummy_input=??, masks_file=masks).speedup_model()
I checked the Microsoft NNI documentation and it only provides a Conv-based demo. Do you know what the dummy_input (input size) of a Transformer-based model should be?
ref: https://nni.readthedocs.io/en/latest/tutorials/pruning_quick_start_mnist.html
@robotsp About that dummy_input: it is an example of an acceptable input for that model.
For blenderbot in huggingface, for example, it is written as:
@property
def dummy_inputs(self):
    pad_token = self.config.pad_token_id
    input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
    dummy_inputs = {
        "attention_mask": input_ids.ne(pad_token),
        "input_ids": input_ids,
        "decoder_input_ids": input_ids,
    }
    return dummy_inputs
And this dict is accepted as-is by blenderbot's forward(**dummy_inputs).
So for a fairseq transformer, you need data that suits TransformerModel's forward.
From that forward, you probably need:
src_tokens,          # 2D tensor (batch_size * sentence_length); each element is no bigger than your vocab size
src_lengths,         # 1D tensor; each entry is the length of the given sentence
prev_output_tokens,  # 2D tensor (batch_size * sentence_length); each element is no bigger than your vocab size
Pruning probably requires a simple run-through (like an evaluation), so you need to provide data here.
I hope this helps, but I am also new to this field, so let me know whether it works. I may not be able to reply very soon.
--
I do not know whether there is a recommended sentence length or batch size. If you are very serious about this, I recommend trying 2-3 different inputs, say one actually taken from your data and one made of tensors of random integers.
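For what it is worth, a dummy batch along those lines might be built like this (batch size, sentence length, and vocab size are made-up numbers):

import torch

batch_size, seq_len, vocab_size = 8, 20, 10000  # made-up sizes

dummy_input = {
    # token ids strictly below the vocab size; ids 0-3 are usually
    # reserved for special symbols in fairseq dictionaries
    "src_tokens": torch.randint(4, vocab_size, (batch_size, seq_len)),
    # pretend every sentence in the fake batch has the full length
    "src_lengths": torch.full((batch_size,), seq_len, dtype=torch.long),
    "prev_output_tokens": torch.randint(4, vocab_size, (batch_size, seq_len)),
}
# a fairseq TransformerModel should then accept model(**dummy_input)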
@gmryu I tried assigning a tensor of random values as input and it raised an error at
ModelSpeedup(model, dummy_input=torch.rand([64,50,50,64,50]), masks_file=masks).speedup_model()
error msg:
NotImplementedError Traceback (most recent call last)
<ipython-input-53-4204d8c2effc> in <module>
----> 1 ModelSpeedup(model, dummy_input=torch.rand([64,50,50,64,50]), masks_file=masks).speedup_model()
/usr/local/lib/python3.6/site-packages/nni/compression/pytorch/speedup/compressor.py in __init__(self, model, dummy_input, masks_file, map_location, batch_dim, confidence)
55 self.batch_dim = batch_dim
56 self.dummy_input, self.device = self._random_model_input(dummy_input, confidence, batch_dim)
---> 57 self.torch_graph = build_module_graph(model, self.dummy_input)
58 # dict object to save the auto inferences objects of the submodules
59 self.auto_inferences = {}
/usr/local/lib/python3.6/site-packages/nni/common/graph_utils.py in build_module_graph(model, dummy_input)
22
23 def build_module_graph(model, dummy_input):
---> 24 return TorchModuleGraph(model, dummy_input)
25
26
/usr/local/lib/python3.6/site-packages/nni/common/graph_utils.py in __init__(self, model, dummy_input, traced_model)
250
251 def __init__(self, model=None, dummy_input=None, traced_model=None):
--> 252 super().__init__(model, dummy_input, traced_model)
253 self.global_count = 0
254 self.reused_module = set()
/usr/local/lib/python3.6/site-packages/nni/common/graph_utils.py in __init__(self, model, dummy_input, traced_model)
64 elif model is not None and dummy_input is not None:
65 self.bound_model = model
---> 66 self._trace(model, dummy_input)
67 else:
68 raise Exception(
/usr/local/lib/python3.6/site-packages/nni/common/graph_utils.py in _trace(self, model, dummy_input)
76 # only pytorch with version greater than 1.6.0 has the strict option
77 kw_args['strict'] = False
---> 78 self.trace = torch.jit.trace(model, dummy_input, **kw_args)
79 torch._C._jit_pass_inline(self.trace.graph)
80 model.train(training)
/usr/local/lib/python3.6/site-packages/torch/jit/_trace.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
740 strict,
741 _force_outplace,
--> 742 _module_class,
743 )
744
/usr/local/lib/python3.6/site-packages/torch/jit/_trace.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
938 var_lookup_fn,
939 strict,
--> 940 _force_outplace,
941 )
942 check_trace_method = module._c._get_method(method_name)
/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
723 input = result
724 if torch._C._get_tracing_state():
--> 725 result = self._slow_forward(*input, **kwargs)
726 else:
727 result = self.forward(*input, **kwargs)
/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
707 recording_scopes = False
708 try:
--> 709 result = self.forward(*input, **kwargs)
710 finally:
711 if recording_scopes:
/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py in _forward_unimplemented(self, *input)
173 registered hooks while the latter silently ignores them.
174 """
--> 175 raise NotImplementedError
176
177
NotImplementedError:
@gmryu Does this mean NNI does not support the fairseq framework yet?
@robotsp Sorry, I cannot reply during my work hours. (Also, have you actually read everything I wrote? Is my phrasing too hard to understand?)
Nonetheless, you may need to know:
- dummy_input is literally a "dummy" batch of data for the model. It is a Python dictionary, not a single tensor.
- What class is the model you used here? Can it be fairseq's TransformerModel?
If you are using fairseq's TransformerModel as the model, it requires a batch like:
# if the model is a fairseq Transformer,
dummy_input = {
    "src_tokens": ...,          # 2D tensor (batch_size * sentence_length); each element is no bigger than your vocab size
    "src_lengths": ...,         # 1D tensor; each entry is the length of the given sentence
    "prev_output_tokens": ...,  # 2D tensor (batch_size * sentence_length); each element is no bigger than your vocab size
}
If it cannot be a TransformerModel, you need to find out what its forward needs, i.e. what kind of batch that model can accept.
@gmryu I tried using the PyTorch nn.prune utils to prune a fairseq model, and it can successfully set the less important weights to zero. I save the state_dict using torch.save. The problem is that when I load it and retrain the model, the zeroed weights get updated again.
Here are the steps:
1. Use fairseq's from_pretrained to load the model and prune it using nn.prune (this gives a new state_dict with pruned weights).
2. Use torch.load() to get a new model and assign the new state_dict to it.
3. Use torch.save() to save the new model with the pruned weights.
4. Use fairseq-train to retrain the model. (But the pruned weights are updated again..)
How can I freeze these unstructured pruned weights and keep them from updating during retraining?
@robotsp If you want to keep a tensor from updating, I guess the simplest way is to freeze it in the code directly.
For example, you can copy a TransformerModel implementation and freeze those weights in its build_model method and load_state_dict method. Then use --user-dir to import your custom model and --arch to tell fairseq to use your custom model class.
(Or you may try freezing some tensors, i.e. model weights, and reloading again; that will tell you whether saved tensors keep their frozen state or not.)
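As a rough, untested sketch of that idea (the model and arch names are made up, and the registration details can differ between fairseq versions):

from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerModel, base_architecture

@register_model("frozen_transformer")  # hypothetical name
class FrozenTransformerModel(TransformerModel):
    @classmethod
    def build_model(cls, args, task):
        model = super().build_model(args, task)
        # freeze whatever you need; the encoder embeddings are
        # only an example target
        model.encoder.embed_tokens.weight.requires_grad = False
        return model

@register_model_architecture("frozen_transformer", "frozen_transformer")
def frozen_transformer(args):
    base_architecture(args)

# then: fairseq-train ... --user-dir /path/to/this/module --arch frozen_transformer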
@gmryu I have checked some weight-freezing methods, but all of them freeze a whole module rather than individual tensors (in my case, I only want to freeze the zero weights from pruning). requires_grad influences the whole module, right?
@robotsp check out this: https://discuss.pytorch.org/t/how-do-i-freeze-the-specific-weights-in-a-layer/104722/5
> (Or you may try freezing some tensors, i.e. model weights, and reloading again; that will tell you whether saved tensors keep their frozen state or not.)

It is similar to the pruning mechanism from torch.nn, but for a fairseq model it cannot be applied directly; the workaround takes these steps:
1. Use fairseq's from_pretrained to load the model and prune it using nn.prune (this gives a new state_dict with pruned weights).
2. Use torch.load() to get a new model and assign the new state_dict to it.
3. Use torch.save() to save the new model with the pruned weights.
4. Use fairseq-train to retrain the model. (But the pruned weights are updated again..)
Is there any method to freeze those zero tensors after reloading?
If you want to freeze only a slice of a single weight tensor, the short answer is that it is difficult and you need to modify fairseq-train. The previous link talks about it.
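Roughly, the trick discussed in that link is to zero the gradient wherever the weight is already zero. An untested sketch:

import torch
import torch.nn as nn

def keep_pruned_entries_zero(module):
    # assumes pruned entries of module.weight are exactly zero
    mask = (module.weight != 0).float()
    # multiply the incoming gradient by the mask on every backward pass
    module.weight.register_hook(lambda grad: grad * mask)

# tiny demo on a stand-in Linear layer; in practice you would loop
# over the Linear modules of your pruned fairseq model instead
layer = nn.Linear(4, 4)
with torch.no_grad():
    layer.weight[layer.weight.abs() < 0.3] = 0.0  # fake "pruned" zeros
keep_pruned_entries_zero(layer)
layer(torch.randn(2, 4)).sum().backward()
assert (layer.weight.grad[layer.weight == 0] == 0).all()

Note that optimizers with weight decay or momentum can still drift those entries away from zero, so you may also need to re-zero them after each optimizer step.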
@gmryu I found that params.requires_grad = False isn't working to freeze weights in a fairseq model. I set requires_grad = False on the parameters, saved the model with torch.save(), and retrained with fairseq-train, but the weights are still updated. Is there any problem in my implementation?
@robotsp Sorry, I have no confidence in these since I have not tried them yet. I guess I would need to try it myself later; however, I only have time on weekends, so it may take a while.
If you want to solve it now, I would suggest adjusting the TransformerModel implementation to freeze parts of the model during fairseq-train. I believe the build_model method and the load_state_dict method are where you can freeze those parameters.
@gmryu Thanks. I'll give it a try then.
@gmryu I have adjusted the code in fairseq_cli/train.py, I defined a pruning function:
def prune_model_global_unstructured(model, layer_type, proportion):
from torch.nn.utils import prune
import torch.nn as nn
module_tups = []
for module in model.modules():
if isinstance(module, layer_type):
module_tups.append((module, 'weight'))
prune.global_unstructured(
parameters=module_tups, pruning_method=prune.L1Unstructured,
amount=proportion
)
for module, _ in module_tups:
prune.remove(module, 'weight')
return model
and call it after building the model in the main function:
# Build model and criterion
model = task.build_model(args)
model = prune_model_global_unstructured(model, nn.Linear, 0.5)  # nn.Linear is the module type you want to prune, and 0.5 is the proportion to prune
and then use the fairseq-train command to start finetuning. Does it make sense?
It starts working.
@robotsp Congratulations! It is nice of you to let me know this works. I hope it solves your issue.
@gmryu you are so welcome.