huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Can I train a BART model from scratch with transformers? #5096

Closed · ScottishFold007 closed 4 years ago

ScottishFold007 commented 4 years ago

Can I train a BART model from scratch with transformers?

patrickvonplaten commented 4 years ago

Yes

ScottishFold007 commented 4 years ago

Yes

That's awesome! Can you give a code example to show how? I'd be grateful!

patrickvonplaten commented 4 years ago

So from the paper (https://arxiv.org/pdf/1910.13461.pdf) you can see that BART is trained to denoise input sequences that can be corrupted in almost any possible way.

One way, using BartForConditionalGeneration, could be:

from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig

tok = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration(BartConfig())

input_string = "My dog is <mask> </s>"
decoder_input_string = "<s> My dog is cute"
labels_string = "My dog is cute </s>"

input_ids = tok(input_string, add_special_tokens=False, return_tensors="pt").input_ids
decoder_input_ids = tok(decoder_input_string, add_special_tokens=False, return_tensors="pt").input_ids
labels = tok(labels_string, add_special_tokens=False, return_tensors="pt").input_ids

loss = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels)[0]  # first output is the loss because labels are provided

patrickvonplaten commented 4 years ago

Pinging @sshleifer to make sure I did not forget anything

ScottishFold007 commented 4 years ago

Pinging @sshleifer to make sure I did not forget anything

Actually, I was going to ask. how train a model from zero to one. For example, I want to train a Chinese bart model.

tomhosking commented 4 years ago

Here's a working example for this, including batching:

from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig

tok = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration(BartConfig())

input_batch = ["My dog is <mask></s>", "It loves to play in the <mask></s>"]
decoder_input_batch = ["<s>My dog is cute", "<s>It loves to play in the park"]
labels_batch = ["My dog is cute</s>", "It loves to play in the park</s>"]

input_ids = tok.batch_encode_plus(input_batch, add_special_tokens=False, return_tensors="pt", padding=True).input_ids
decoder_input_ids = tok.batch_encode_plus(decoder_input_batch, add_special_tokens=False, return_tensors="pt", padding=True).input_ids
labels = tok.batch_encode_plus(labels_batch, add_special_tokens=False, return_tensors="pt", padding=True).input_ids

loss = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels)[0]

>>> tensor(10.9981, device='cuda:0', grad_fn=<NllLossBackward>)
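
For completeness, a minimal sketch of how one might take an optimization step with this loss (the optimizer choice and learning rate here are assumptions, not something from this thread):

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # hypothetical hyperparameters

model.train()
outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, labels=labels)
loss = outputs[0]       # first element is the loss when labels are passed
loss.backward()         # backpropagate through the randomly initialized model
optimizer.step()        # update the weights
optimizer.zero_grad()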

ScottishFold007 commented 4 years ago

Here's a working example for this, including batching: [...]

input_batch = ["My dog is ", "It loves to play in the "] decoder_input_batch = ["My dog is cute", "It loves to play in the park"] labels_batch = ["My dog is cute", "It loves to play in the park"]

If I have a text document with one paragraph per line, how do I build these data inputs from it? Thanks!
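
A minimal sketch of one way to build those lists from such a file (the file name corpus.txt is made up, and masking a single random word per line is only a toy noising strategy, not the paper's full scheme):

import random
from transformers import BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-large")

# hypothetical file with one paragraph per line
with open("corpus.txt", encoding="utf-8") as f:
    lines = [line.strip() for line in f if line.strip()]

input_batch, decoder_input_batch, labels_batch = [], [], []
for line in lines:
    words = line.split()
    words[random.randrange(len(words))] = "<mask>"   # corrupt the encoder input
    input_batch.append(" ".join(words) + "</s>")
    decoder_input_batch.append("<s>" + line)
    labels_batch.append(line + "</s>")

input_ids = tok(input_batch, add_special_tokens=False, return_tensors="pt", padding=True).input_ids
decoder_input_ids = tok(decoder_input_batch, add_special_tokens=False, return_tensors="pt", padding=True).input_ids
labels = tok(labels_batch, add_special_tokens=False, return_tensors="pt", padding=True).input_ids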

swethmandava commented 3 years ago

@tomhosking the paper indicates that BART uses both sentence permutation (the loss is propagated from all tokens instead of only masked tokens) and text infilling (a single mask token covers multiple consecutive masked tokens). Would this be a correct input?

input_batch = ["\It is \<mask> retriever. My dog is \<mask>\", "\There \<mask> in SF. It loves to play in the \<mask>\"] decoder_input_batch = ["\\My dog is cute. It is a golden retriever", "\\It loves to play in the park. There are many parks in SF."] labels_batch = ["\My dog is cute. It is a golden retriever\", "\It loves to play in the park. There are many parks in SF.\"]

(Note: decoder_input_batch should start with </s><s> because of shift_tokens_right, see #7961)
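
A rough word-level sketch of those two corruptions (span lengths drawn from Poisson(λ=3) as in the paper; the helper names are made up and this simplifies the fairseq logic considerably):

import random
import numpy as np

def permute_sentences(text):
    # sentence permutation: shuffle the order of the sentences
    sentences = [s.strip() for s in text.split(".") if s.strip()]
    random.shuffle(sentences)
    return ". ".join(sentences) + "."

def infill_text(text, mask_ratio=0.3, poisson_lambda=3.0):
    # text infilling: replace whole spans of words with a single <mask> token
    words = text.split()
    budget = int(round(len(words) * mask_ratio))
    masked = 0
    while masked < budget:
        span_len = np.random.poisson(poisson_lambda)
        start = random.randrange(len(words))
        words[start:start + span_len] = ["<mask>"]   # a 0-length span just inserts a <mask>
        masked += max(span_len, 1)
    return " ".join(words)

text = "My dog is cute. It is a golden retriever."
corrupted = infill_text(permute_sentences(text))
input_example = corrupted + "</s>"            # encoder sees the corrupted text
decoder_input_example = "</s><s>" + text      # decoder input is the shifted original
labels_example = "<s>" + text + "</s>"        # loss is computed on the original text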

jonatasgrosman commented 3 years ago

Sorry for the intrusion, but I think your values are almost correct @swethmandava, except for the missing special tokens:

input_batch = ["<s>It <mask> retriever. My <mask> cute </s>", ... ]
decoder_input_batch = ["</s><s>My dog is cute. It is a golden retriever", ...]
labels_batch = ["<s>My dog is cute. It is a golden retriever</s>", ...]

BTW: This </s> token at the beginning of the decoder's input is kind of weird to me, but it's inherited from the original fairseq code. If you want to train the model from scratch with random weights, I think you can go without it... or maybe this trick is important for convergence, we never know :grin:
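
For reference, a small sketch of what that shift does (this mirrors transformers' shift_tokens_right for BART, whose decoder_start_token_id is the </s> id; the token ids below are just illustrative):

import torch

def shift_right(labels, decoder_start_token_id):
    # prepend the decoder start token and drop the last position
    decoder_input_ids = labels.new_zeros(labels.shape)
    decoder_input_ids[:, 1:] = labels[:, :-1].clone()
    decoder_input_ids[:, 0] = decoder_start_token_id
    return decoder_input_ids

labels = torch.tensor([[0, 31414, 232, 2]])                        # e.g. <s> ... </s>
decoder_input_ids = shift_right(labels, decoder_start_token_id=2)  # 2 is BART's </s> id
# decoder_input_ids is now [[2, 0, 31414, 232]], i.e. it starts with </s><s>

Recent versions of BartForConditionalGeneration apply this shift automatically when you pass labels without decoder_input_ids.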

HuipengXu commented 3 years ago

Will masking only 15% of the encoder input cause some kind of leakage, so that the language model in the decoder cannot learn correctly?

prajdabre commented 3 years ago

If anyone wants to train their own MBART model, feel free to use this: https://github.com/prajdabre/yanmtt

Contributions are welcome!

jbmaxwell commented 2 years ago

Sorry for the intrusion, but I think your values are almost correct @swethmandava, except for the missing special tokens: [...]

BTW: This </s> token at the beginning of the decoder's input is kind of weird to me, but it's inherited from the original fairseq code. [...]

I have a non-natural language dataset where I haven't actually been including <s> and </s> since they don't add any value (and need to be removed later anyway). To work with that, should I insert a pad token at the start of the decoder_input representation (and truncate to max_length)?
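
The thread doesn't answer this directly, but for what it's worth, the decoder start token is configurable when building a model from scratch; a hedged sketch (whether starting from the pad token trains well is exactly the open question here):

from transformers import BartConfig, BartForConditionalGeneration, BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-large")

# assumption: start decoding from the pad token instead of </s>
config = BartConfig(
    vocab_size=len(tok),
    decoder_start_token_id=tok.pad_token_id,
    pad_token_id=tok.pad_token_id,
    bos_token_id=tok.bos_token_id,
    eos_token_id=tok.eos_token_id,
)
model = BartForConditionalGeneration(config)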

Haiming94 commented 2 years ago

So from the paper (https://arxiv.org/pdf/1910.13461.pdf) you can see that BART is trained to denoise input sequences that can be corrupted in almost any possible way. [...]

Hi, do you have a script to build the training dataset for BART pretraining? Thanks!

BramVanroy commented 2 years ago

@patrickvonplaten @sshleifer Did anyone ever come around to creating a notebook/script for BART pretraining? (In a linked issue you mentioned it was on the to-do list.)

The core difficulty is having a canonical implementation of the data preprocessing (BART is more than just token masking, I believe: e.g., span masking and sentence shuffling). But a full pretraining pipeline here or in fairseq is also sorely missing.
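
As a sketch of what the PyTorch side of such a pipeline could look like, a bare-bones collator for use with the Trainer (the noise_fn that combines span masking and sentence shuffling is assumed to exist elsewhere; only the plumbing is shown):

class BartDenoisingCollator:
    def __init__(self, tokenizer, noise_fn, max_length=512):
        self.tokenizer = tokenizer
        self.noise_fn = noise_fn        # e.g. span masking + sentence permutation
        self.max_length = max_length

    def __call__(self, examples):
        texts = [ex["text"] for ex in examples]          # assumes a "text" field
        corrupted = [self.noise_fn(t) for t in texts]

        batch = self.tokenizer(corrupted, padding=True, truncation=True,
                               max_length=self.max_length, return_tensors="pt")
        labels = self.tokenizer(texts, padding=True, truncation=True,
                                max_length=self.max_length, return_tensors="pt").input_ids
        labels[labels == self.tokenizer.pad_token_id] = -100   # ignore padding in the loss
        batch["labels"] = labels
        # recent versions of BartForConditionalGeneration build decoder_input_ids
        # from labels via shift_tokens_right when they are not passed explicitly
        return batch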

patrickvonplaten commented 2 years ago

Sadly not :-/ We now have one for Flax in #18297 - could you maybe try to port the preprocessing logic into a PyTorch version?

BramVanroy commented 2 years ago

@patrickvonplaten I've been porting the fairseq implementation to a PyTorch dataloader format. I found that the Flax implementation in HF lacks adding noise for 0-length spans and diverges slightly in some other places, so it was more straightforward to start from the fairseq implementation. I am now testing the data processing in particular, to get it as close as possible to fairseq's implementation (although it is my belief that there's a bug in their code).

I would like to add a full PyTorch example for denoising language model (DLM) training of BART in the coming days/weeks, but I could use some code reviews along the way to feel more comfortable. Would that be possible?

patrickvonplaten commented 2 years ago

Sure, happy to take a look!

prajdabre commented 2 years ago

Hi

I remember posting this a year ago but I've written an entire toolkit for this purpose. Feel free to use it. https://github.com/prajdabre/yanmtt

I've also created a simple notebook for the same (scroll to the pretraining part): https://colab.research.google.com/drive/1ovlA_h0ggblawqR-yCgRs3uRjxFJ8K0l?usp=sharing

BramVanroy commented 2 years ago

Hi Raj, thank you for this. I had come across it, but your script seems to have a lot of additional things going on, which makes it hard to extract the basics. I also found that you implement word/span masking but not the other things, like adding noise or randomly swapping a masked token for a random token, so it's not completely like the original implementation (but correct me if I'm wrong!).

I think your library can be very useful as a standalone library, thanks! In addition, I'll try to add a PR to transformers with a succinct example that works within transformers and the Trainer, with data processing close to the fairseq implementation.

prajdabre commented 2 years ago

Hi,

My focus was more on mBART and mT5, which only use span masking and reordering. I'm not sure whether token replacement will have that big of an impact, but it can easily be implemented in one line. To my understanding, span masking is responsible for the majority of the gains. The notebook contains a more watered-down version of the masking method in my toolkit. You could take that version and build on top of it easily.
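
For illustration, token replacement on a tensor of token ids really is roughly a one-liner (probability, vocabulary size, and ids below are placeholders, and a real version would protect special tokens):

import torch

vocab_size, replace_prob = 50265, 0.1             # placeholder values
input_ids = torch.tensor([[0, 31414, 232, 2]])    # hypothetical batch of token ids

replace_mask = torch.rand(input_ids.shape) < replace_prob
input_ids = torch.where(replace_mask, torch.randint(vocab_size, input_ids.shape), input_ids)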

CountingMstar commented 1 year ago

Hey guys, I would like to know how to pretrain a BART model from scratch. Does anyone know how to do this? BART, Pegasus, or other text summarization models would all work for me.