@thomwolf @sshleifer
Yeah, here is working code for the copy approach, starting with bart-large-cnn weights.
from copy import deepcopy
from transformers import BartForConditionalGeneration
# Get original model
model = BartForConditionalGeneration.from_pretrained('bart-large-cnn')
sd = model.state_dict()
shorter_pos_embeds = sd['model.encoder.embed_positions.weight']
new_config = model.config
new_config.max_position_embeddings = 2048
new_model = BartForConditionalGeneration(new_config)
# If you want to learn everything from scratch, you can stop here.
correctly_shaped_pos_weight = new_model.model.encoder.embed_positions.weight
correctly_shaped_pos_weight[:shorter_pos_embeds.shape[0]] = shorter_pos_embeds
sd['model.decoder.embed_positions.weight'] = correctly_shaped_pos_weight
sd['model.encoder.embed_positions.weight'] = correctly_shaped_pos_weight
new_model.load_state_dict(sd, strict=True)
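Once the surgery succeeds, one way to reuse the resized model from a fine-tuning script is to save it to a local directory and point the script's model name argument at that path. This is only a sketch; the directory name is an example, not something from the thread.
import os
from transformers import BartTokenizer

# Persist the resized model together with its updated config (max_position_embeddings = 2048).
save_dir = 'bart-large-cnn-2048'
os.makedirs(save_dir, exist_ok=True)
new_model.save_pretrained(save_dir)

# The tokenizer is unchanged, but saving a copy keeps the directory self-contained.
BartTokenizer.from_pretrained('facebook/bart-large-cnn').save_pretrained(save_dir)
A fine-tuning script can then load it with from_pretrained(save_dir).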
@sshleifer Can we use this same approach for T5 as well?
Does it make sense to initialize the new positions by copying the first 512 pretrained weights multiple times?
Like: [[0:512], [0:512], [0:512], [0:512]]
So the model only has to learn the "relativeness" between positions?
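For what it's worth, T5 uses relative position buckets inside attention rather than a learned absolute position-embedding table, so this particular surgery does not carry over directly. As for the tiling idea, here is a minimal sketch of what it could look like; tile_position_embeddings is a hypothetical helper, not something from the thread or the library.
import torch

def tile_position_embeddings(old_weight: torch.Tensor, new_len: int) -> torch.Tensor:
    """Fill a new position table of length new_len by repeating the pretrained rows."""
    old_len = old_weight.shape[0]
    n_copies = -(-new_len // old_len)  # ceil(new_len / old_len)
    return old_weight.repeat(n_copies, 1)[:new_len].clone()

# e.g., reusing names from the snippet above, sized to match the new table exactly:
# target_len = new_model.model.encoder.embed_positions.weight.shape[0]
# new_weight = tile_position_embeddings(shorter_pos_embeds, target_len)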
I am editing finetune.py
model = SummarizationTrainer(args)
sd = model.model.state_dict()
shorter_pos_embeds = sd['model.encoder.embed_positions.weight']
new_config = model.config
new_config.max_position_embeddings = 4096
new_model = BartForConditionalGeneration(new_config)
correctly_shaped_pos_weight = new_model.model.encoder.embed_positions.weight
correctly_shaped_pos_weight[:shorter_pos_embeds.shape[0]] = shorter_pos_embeds
sd['model.decoder.embed_positions.weight'] = correctly_shaped_pos_weight
sd['model.encoder.embed_positions.weight'] = correctly_shaped_pos_weight
new_model.load_state_dict(sd, strict=True)
model.model = new_model.cuda()
trainer = generic_train(model, args)
INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json from cache at /u/tuhin1/.cache/torch/transformers/1ae1f5b6e2b22b25ccc04c000bb79ca847aa226d0761536b011cf7e5868f0655.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b
INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt from cache at /u/tuhin1/.cache/torch/transformers/f8f83199a6270d582d6245dc100e99c4155de81c9745c6248077018fe01abcfb.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda
INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/facebook/bart-large/pytorch_model.bin from cache at /u/tuhin1/.cache/torch/transformers/2e7cae41bb1dd1f18e498ff4ff0ea85f7e9bc2b637439e2d95c485c5d5bdd579.5be2a88ec29f5969270f98902db392beab8be8a6a7ecc588d410ada3e32c4263
INFO:transformers.modeling_utils:Weights of BartForConditionalGeneration not initialized from pretrained model: ['final_logits_bias']
INFO:transformers.modeling_utils:Weights from pretrained model not used in BartForConditionalGeneration: ['encoder.version', 'decoder.version']
INFO:lightning:GPU available: True, used: True
INFO:lightning:CUDA_VISIBLE_DEVICES: [0]
INFO:lightning:Using 16bit precision.
/dccstor/tuhinstor/condatuhin/envs/BARTQA/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:23: RuntimeWarning: You have defined a `val_dataloader()` and have defined a `validation_step()`, you may also want to define `validation_epoch_end()` for accumulating stats.
warnings.warn(*args, **kwargs)
Traceback (most recent call last):
File "/dccstor/tuhinstor/transformers/examples/finetune1.py", line 196, in <module>
main(args)
File "/dccstor/tuhinstor/transformers/examples/finetune1.py", line 177, in main
trainer = generic_train(model, args)
File "/dccstor/tuhinstor/transformers/examples/lightning_base.py", line 278, in generic_train
trainer.fit(model)
File "/dccstor/tuhinstor/condatuhin/envs/BARTQA/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 765, in fit
self.single_gpu_train(model)
File "/dccstor/tuhinstor/condatuhin/envs/BARTQA/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 484, in single_gpu_train
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
File "/dccstor/tuhinstor/condatuhin/envs/BARTQA/lib/python3.7/site-packages/pytorch_lightning/trainer/optimizers.py", line 18, in init_optimizers
optim_conf = model.configure_optimizers()
File "/dccstor/tuhinstor/transformers/examples/lightning_base.py", line 87, in configure_optimizers
optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)
File "/dccstor/tuhinstor/transformers/src/transformers/optimization.py", line 117, in __init__
super().__init__(params, defaults)
File "/dccstor/tuhinstor/condatuhin/envs/BARTQA/lib/python3.7/site-packages/torch/optim/optimizer.py", line 51, in __init__
self.add_param_group(param_group)
File "/dccstor/tuhinstor/condatuhin/envs/BARTQA/lib/python3.7/site-packages/torch/optim/optimizer.py", line 202, in add_param_group
raise ValueError("can't optimize a non-leaf Tensor")
ValueError: can't optimize a non-leaf Tensor
Any idea how to resolve this, @sshleifer?
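For context, AdamW (like every torch.optim optimizer) only accepts leaf tensors, and this error means at least one parameter ended up as a non-leaf tensor after the weight surgery. A minimal sketch of a variant that keeps every parameter a leaf, by doing the copy on detached tensors under torch.no_grad() before load_state_dict (names follow the snippets above):
import torch
from transformers import BartForConditionalGeneration

model = SummarizationTrainer(args)  # as above
sd = model.model.state_dict()
shorter_pos_embeds = sd['model.encoder.embed_positions.weight']

new_config = model.config
new_config.max_position_embeddings = 4096
new_model = BartForConditionalGeneration(new_config)

with torch.no_grad():
    # Work on a detached copy so no autograd graph is attached to the new weights.
    new_pos = new_model.model.encoder.embed_positions.weight.detach().clone()
    new_pos[:shorter_pos_embeds.shape[0]] = shorter_pos_embeds
    sd['model.encoder.embed_positions.weight'] = new_pos
    sd['model.decoder.embed_positions.weight'] = new_pos.clone()

# load_state_dict copies the values into the existing (leaf) parameters.
new_model.load_state_dict(sd, strict=True)
model.model = new_model.cuda()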
Update: I fixed it.
model = SummarizationTrainer(args)
sd = model.model.state_dict()
shorter_pos_embeds = sd['model.encoder.embed_positions.weight']
new_config = model.config
new_config.max_position_embeddings = 3000
new_model = BartForConditionalGeneration(new_config)
correctly_shaped_pos_weight = new_model.model.encoder.embed_positions.weight.cuda()
correctly_shaped_pos_weight[:shorter_pos_embeds.shape[0]] = shorter_pos_embeds.cuda()
sd['model.decoder.embed_positions.weight'] = correctly_shaped_pos_weight
sd['model.encoder.embed_positions.weight'] = correctly_shaped_pos_weight
new_model.load_state_dict(sd, strict=True)
model.model = new_model.cuda()
trainer = generic_train(model, args)
However, I get OOM. I have a 32 GB NVIDIA V100; I feel it caches a bit too much. Any workaround? RuntimeError: CUDA out of memory. Tried to allocate 550.00 MiB (GPU 0; 31.72 GiB total capacity; 26.46 GiB already allocated; 81.88 MiB free; 4.04 GiB cached)
My batch size is 1
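Two things worth noting. With full self-attention, activation memory grows roughly quadratically with sequence length, so 3000-token inputs are far heavier than 1024-token inputs even at batch size 1. Also, the "cached" figure is memory PyTorch's allocator keeps around for reuse; a hedged sketch of how to drop the temporary copies from the surgery above and inspect what is actually allocated (the deleted names refer to that snippet):
import gc
import torch

# Drop local references to the intermediate copies made during the surgery
# (model.model still holds the resized network, so this is safe).
del sd, shorter_pos_embeds, correctly_shaped_pos_weight
gc.collect()
torch.cuda.empty_cache()  # return cached blocks to CUDA

print(f"allocated: {torch.cuda.memory_allocated() / 2**30:.2f} GiB")
print(f"reserved:  {torch.cuda.memory_reserved() / 2**30:.2f} GiB")  # memory_cached() on older PyTorch
If the attention activations themselves do not fit, this only buys back the cached slack; the usual levers are shortening the source length or reducing what is kept in memory during the forward pass.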
@tuhinjubcse were you able to solve the OOM error? Is it possible for you to share your working script for fine-tuning with longer sequences? I'd appreciate it.
Which file are you editing?
❓ Questions & Help
Details
The fairseq folks say we can fine-tune the BART model with a longer seq_len on our custom training data. They pre-trained BART with a 512 seq_len and fine-tuned with a 1024 seq_len. One can raise it further (say, 2048) and fine-tune.
For the above, we would need to adjust the positional embeddings by either:
1. learning them from scratch, or
2. copying the first 512 pretrained position embeddings and learning only the new positions,
with 2 being recommended. If I have to do it, any pointers on where to change the code?
https://github.com/pytorch/fairseq/issues/1685