huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

BertGenerationDecoder .generate() issue during inference with PyTorch Lightning #9686

Closed anicolson closed 3 years ago

anicolson commented 3 years ago

Environment info

Who can help

TextGeneration: @TevenLeScao
Text Generation: @patrickvonplaten
examples/seq2seq: @patil-suraj

Information

I am using BertGenerationEncoder and BertGenerationDecoder, with transformers in combination with PyTorch Lightning.

At inference, .generate() outputs the same thing for each input.

I am unsure why this is occurring. My only hunch is that PyTorch Lightning is somehow preventing the outputs of the encoder from reaching the decoder for cross-attention, as the outputs look as though the decoder is given only the [BOS] token for each input during inference.

The task that I am demonstrating this issue on is:

I have had this problem occur on different tasks as well. I am using WMT'14 English to German to demonstrate it.

To reproduce

I have tried to simplify this down, but unfortunately, the example is still long. Sorry about that. Please let me know if something does not work.

If torchnlp is not installed: pip install pytorch-nlp
If pytorch_lightning is not installed: pip install pytorch-lightning

from torchnlp.datasets.wmt import wmt_dataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.metrics.functional.nlp import bleu_score
import pytorch_lightning as pl
from transformers import (
    BertGenerationConfig,
    BertGenerationEncoder,
    BertGenerationDecoder,
)
from transformers import AutoTokenizer
import os
import numpy as np
import multiprocessing

class Dataset(LightningDataModule):

    def __init__(
        self,
        mbatch_size,
        dataset_path,
        encoder_tokenizer,
        decoder_tokenizer,
        max_len=None,
        **kwargs,
    ):

        super().__init__()
        self.mbatch_size = mbatch_size
        self.dataset_path = dataset_path
        self.encoder_tokenizer = encoder_tokenizer
        self.decoder_tokenizer = decoder_tokenizer
        self.max_len = max_len

        ## Number of workers for DataLoader
        self.n_workers = multiprocessing.cpu_count()

    def setup(self, stage=None):

        ## Assign train & validation sets
        if stage == "fit" or stage is None:
            train_iterator, val_iterator = wmt_dataset(
                directory=self.dataset_path,
                train=True,
                dev=True,
            )
            self.train_set = Set(
                train_iterator,
                self.encoder_tokenizer,
                self.decoder_tokenizer,
                self.max_len,
            )
            self.val_set = Set(
                val_iterator,
                self.encoder_tokenizer,
                self.decoder_tokenizer,
                self.max_len,
            )

        ## Assign test set
        if stage == "test" or stage is None:
            test_iterator = wmt_dataset(directory=self.dataset_path, test=True)
            self.test_set = Set(
                test_iterator,
                self.encoder_tokenizer,
                self.decoder_tokenizer,
                self.max_len,
            )

    def train_dataloader(self):

        return DataLoader(
            self.train_set,
            batch_size=self.mbatch_size,
            num_workers=self.n_workers,
            shuffle=True,
        )

    def val_dataloader(self):

        return DataLoader(
            self.val_set,
            batch_size=self.mbatch_size,
            num_workers=self.n_workers,
        )

    def test_dataloader(self):

        return DataLoader(
            self.test_set,
            batch_size=self.mbatch_size,
            num_workers=self.n_workers,
        )

class Set(torch.utils.data.Dataset):

    def __init__(
        self,
        iterator,
        encoder_tokenizer,
        decoder_tokenizer,
        max_len,
    ):
        self.iterator = iterator
        self.encoder_tokenizer = encoder_tokenizer
        self.decoder_tokenizer = decoder_tokenizer
        self.n_examples = len(self.iterator)
        self.max_len = max_len

    def __getitem__(self, index):

        example = self.iterator[index]

        english_encoded = self.encoder_tokenizer(
            example["en"],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )
        german_encoded = self.decoder_tokenizer(
            example["de"],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )

        return {
            "input_ids": english_encoded["input_ids"][0],
            "token_type_ids": english_encoded["token_type_ids"][0],
            "attention_mask": english_encoded["attention_mask"][0],
            "decoder_input_ids": german_encoded["input_ids"][0],
            "decoder_token_type_ids": german_encoded["token_type_ids"][0],
            "decoder_attention_mask": german_encoded["attention_mask"][0],
        }

    def __len__(self):
        return self.n_examples

class BERT2BERT(nn.Module):

    def __init__(self, **kwargs):
        super(BERT2BERT, self).__init__()

        assert "ckpt_base" in kwargs, "ckpt_base must be passed."
        self.ckpt_base = kwargs["ckpt_base"]

        ## Tokenizer
        assert (
            "encoder_tokenizer" in kwargs
        ), "A tokenizer for the encoder must be passed."
        assert (
            "decoder_tokenizer" in kwargs
        ), "A tokenizer for the decoder must be passed."
        self.encoder_tokenizer = kwargs["encoder_tokenizer"]
        self.decoder_tokenizer = kwargs["decoder_tokenizer"]

        ## Encoder
        assert "encoder_init" in kwargs, "Set encoder_init in config file."
        self.encoder_init = kwargs["encoder_init"]
        ckpt_dir = os.path.join(self.ckpt_base, self.encoder_init)
        self.encoder = BertGenerationEncoder.from_pretrained(ckpt_dir)

        ## Decoder
        assert "decoder_init" in kwargs, "Set decoder_init in config file."
        self.decoder_init = kwargs["decoder_init"]
        ckpt_dir = os.path.join(self.ckpt_base, self.decoder_init)
        config = BertGenerationConfig.from_pretrained(ckpt_dir)
        config.is_decoder = True
        config.add_cross_attention = True
        config.bos_token_id = self.decoder_tokenizer.cls_token_id
        config.eos_token_id = self.decoder_tokenizer.sep_token_id
        config.pad_token_id = self.decoder_tokenizer.pad_token_id
        config.max_length = kwargs["max_length"] if "max_length" in kwargs else 20
        config.min_length = kwargs["min_length"] if "min_length" in kwargs else 10
        config.no_repeat_ngram_size = (
            kwargs["no_repeat_ngram_size"] if "no_repeat_ngram_size" in kwargs else 0
        )
        config.early_stopping = (
            kwargs["early_stopping"] if "early_stopping" in kwargs else False
        )
        config.length_penalty = (
            kwargs["length_penalty"] if "length_penalty" in kwargs else 1.0
        )
        config.num_beams = kwargs["num_beams"] if "num_beams" in kwargs else 1
        self.decoder = BertGenerationDecoder.from_pretrained(
            ckpt_dir,
            config=config,
        )

    def forward(self, x):

        ## Get last hidden state of the encoder
        encoder_hidden_state = self.encoder(
            input_ids=x["input_ids"],
            attention_mask=x["attention_mask"],
        ).last_hidden_state

        ## Teacher forcing: labels are given as input
        outp = self.decoder(
            input_ids=x["decoder_input_ids"],
            attention_mask=x["decoder_attention_mask"],
            encoder_hidden_states=encoder_hidden_state,
        )

        return outp["logits"]

    def generate(self, input_ids, attention_mask):

        ## Get last hidden state of the encoder
        encoder_hidden_state = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state

        print("\n Output of encoder:")
        print(encoder_hidden_state)

        bos_ids = (
            torch.ones(
                (encoder_hidden_state.size()[0], 1),
                dtype=torch.long,
                device=self.decoder.device,
            )
            * self.decoder.config.bos_token_id
        )

        ## Autoregressively generate predictions
        return self.decoder.generate(
            input_ids=bos_ids,
            encoder_hidden_states=encoder_hidden_state,
        )

class Seq2Seq(pl.LightningModule):

    def __init__(
        self,
        encoder_init,
        decoder_init,
        encoder_tokenizer,
        decoder_tokenizer,
        permute_outp=False,
        ckpt_base="",
        ver="tmp",
        print_model=True,
        **kwargs,
    ):
        super(Seq2Seq, self).__init__()
        self.save_hyperparameters()

        self.permute_outp = permute_outp
        self.ckpt_base = ckpt_base
        self.ver = ver

        self.encoder_tokenizer = encoder_tokenizer
        self.decoder_tokenizer = decoder_tokenizer

        self.seq2seq = BERT2BERT(
            encoder_init=encoder_init,
            decoder_init=decoder_init,
            encoder_tokenizer=encoder_tokenizer,
            decoder_tokenizer=decoder_tokenizer,
            ckpt_base=ckpt_base,
            **kwargs,
        )

        ## Loss function
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):

        ## Iterate through the networks
        return self.seq2seq(x)

    def training_step(self, batch, batch_idx):

        ## Target
        y = batch["decoder_input_ids"]

        ## Inference
        y_hat = self(batch)

        ## Permute output
        if self.permute_outp:
            y_hat = y_hat.permute(*self.permute_outp)

        ## Loss
        train_loss = self.loss(y_hat, y)

        ## Compute and log metrics
        logs = {"train_loss": train_loss}
        self.log_dict(logs, on_step=False, on_epoch=True)

        ######### TEMPORARY!!!
        if batch_idx % 100 == 0:
            pred = self.seq2seq.generate(
                batch["input_ids"],
                batch["attention_mask"],
            )
            pred_str = self.decoder_tokenizer.batch_decode(pred, skip_special_tokens=True)

            ref_str = self.decoder_tokenizer.batch_decode(y, skip_special_tokens=True)

            print("\nTraining reference labels:")
            print(ref_str)
            print("\n Training predictions:")
            print(pred_str)
            print("\n\n")

        ## Return training loss
        return train_loss

    def validation_step(self, batch, batch_idx):

        print("\n\n\n Validation input_ids:")
        print(batch["input_ids"])

        ## Generate outputs autoregressively
        pred = self.seq2seq.generate(
            batch["input_ids"],
            batch["attention_mask"],
        )

        pred_str = self.decoder_tokenizer.batch_decode(pred, skip_special_tokens=True)
        ref_str = self.decoder_tokenizer.batch_decode(batch["decoder_input_ids"], skip_special_tokens=True)

        print("Validation reference labels:")
        print(ref_str)
        print("Validation predictions:")
        print(pred_str)
        print("\n\n")

        pred_str = [i.split() for i in pred_str]
        ref_str = [i.split() for i in ref_str]

        self.log_dict({"val_bleu": bleu_score(pred_str, ref_str)})

    def test_step(self, batch, batch_idx):
        ## Generate outputs autoregressively
        pred = self.seq2seq.generate(
            batch["input_ids"],
            batch["attention_mask"],
        )

        pred_str = self.decoder_tokenizer.batch_decode(pred, skip_special_tokens=True)
        ref_str = self.decoder_tokenizer.batch_decode(batch["decoder_input_ids"], skip_special_tokens=True)

        pred_str = [i.split() for i in pred_str]
        ref_str = [i.split() for i in ref_str]

        self.log_dict({"test_bleu": bleu_score(pred_str, ref_str)})

    def configure_optimizers(self):
        self.optimisers = [torch.optim.Adam(self.parameters(), lr=4e-5)]
        return self.optimisers

if __name__ == "__main__":

    ckpt_base = ""
    encoder_init = "bert-base-uncased"
    decoder_init = "dbmdz/bert-base-german-uncased"
    dataset_path = ""

    encoder_tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(ckpt_base, encoder_init),
        )
    decoder_tokenizer = AutoTokenizer.from_pretrained(
            os.path.join(ckpt_base, decoder_init),
        )

    dataset = Dataset(
        mbatch_size=4,
        dataset_path=dataset_path,
        encoder_tokenizer=encoder_tokenizer,
        decoder_tokenizer=decoder_tokenizer,
        max_len=512,
    )

    trainer = pl.Trainer(
        max_epochs=2,
        num_sanity_val_steps=0,
        fast_dev_run=True,
        accelerator="ddp" if torch.cuda.device_count() > 1 else None,
        gpus=torch.cuda.device_count() if torch.cuda.is_available() else None,
        precision=16 if torch.cuda.is_available() else 32,
        ## log_gpu_memory and plugins were set from variables that are not
        ## defined in this script, so they are left at their defaults here.
    )

    seq2seq = Seq2Seq(
        encoder_init=encoder_init,
        decoder_init=decoder_init,
        encoder_tokenizer=encoder_tokenizer,
        decoder_tokenizer=decoder_tokenizer,
        ckpt_base=ckpt_base,
        permute_outp=[0, 2, 1],
    )

    trainer.fit(seq2seq, datamodule=dataset)
    # trainer.test(seq2seq, datamodule=dataset)

Outputs of script demonstrating the issue

During training:

Output of encoder (to demonstrate that there is a difference per input):

tensor([[[-0.1545,  0.0785,  0.4573,  ..., -0.3254,  0.5409,  0.4258],
         [ 0.2935, -0.1310,  0.4843,  ..., -0.4160,  0.8018,  0.2589],
         [ 0.0649, -0.5836,  1.9177,  ..., -0.3412,  0.2852,  0.8098],
         ...,
         [ 0.1109,  0.1653,  0.5843,  ..., -0.3402,  0.1081,  0.2566],
         [ 0.3011,  0.0258,  0.4950,  ..., -0.2070,  0.1684, -0.0199],
         [-0.1004, -0.0299,  0.4860,  ..., -0.2958, -0.1653,  0.0719]],

        [[-0.3105,  0.0351, -0.5714,  ..., -0.1062,  0.3461,  0.8927],
         [ 0.0727,  0.2580, -0.6962,  ...,  0.3195,  0.9559,  0.6534],
         [-0.6213,  0.9008,  0.2194,  ...,  0.1259,  0.1122,  0.7071],
         ...,
         [ 0.2667, -0.1453, -0.2017,  ...,  0.5667, -0.0772, -0.2298],
         [ 0.4050,  0.0916,  0.2218,  ...,  0.0295, -0.2065,  0.1230],
         [-0.1895,  0.0259, -0.1619,  ..., -0.1657, -0.0760, -0.6030]],

        [[-0.1366,  0.2778,  0.1203,  ..., -0.4764,  0.4009,  0.2918],
         [ 0.2401, -0.2308,  1.1218,  ..., -0.2140,  0.7054,  0.6656],
         [-0.7005, -0.9183,  1.6280,  ...,  0.2339, -0.1870,  0.0630],
         ...,
         [-0.0212, -0.2678,  0.0711,  ...,  0.2884,  0.3741, -0.2103],
         [-0.0058, -0.2364,  0.2587,  ...,  0.0689,  0.2010, -0.0315],
         [ 0.1869, -0.0784,  0.2257,  ..., -0.1498,  0.0935, -0.0234]],

        [[ 0.1023,  0.0532,  0.2052,  ..., -0.5335,  0.0676,  0.2436],
         [-0.2254,  1.0484, -0.1338,  ..., -0.9030, -0.1407, -0.2173],
         [-0.8384,  0.3990,  0.6661,  ..., -0.4869,  0.7780, -0.5461],
         ...,
         [ 0.4410,  0.1868,  0.6844,  ..., -0.2972, -0.1069, -0.1848],
         [-0.0021, -0.0537,  0.2477,  ...,  0.1877, -0.0479, -0.3762],
         [ 0.1981,  0.0980,  0.3827,  ...,  0.1449,  0.0403, -0.2863]]],
       grad_fn=<NativeLayerNormBackward>)

Training reference labels:

[
'pau @ @ schal @ @ preis 80 € / person auf basis von 2 person @ @ nen.', 
'ich finde es be @ @ denk @ @ lich, dass der bericht, den wir im ausschuss angenommen haben, so unterschiedlich ausgelegt wird.', 
'die globalisierung hat eine betrachtliche veranderung der bedeutung ge @ @ ok @ @ ultur @ @ eller regionen in der welt mit sich gebracht.', 
'falls sie eigentumer einer immobili @ @ e in andor @ @ ra sind, kontaktieren sie uns, um ihr apartment oder hotel hier auf @ @ zun @ @ ehem @ @ en.',
]

Training predictions after .generate() and .batch_decode() (garbage, but different per input):

[
'##exe int int int int fid fid fid fid fid fid fid fid fid fid fid fid lanz urn', 
'##schleschleually vno stadien stadien stadienherzherzherzherzherzherzherzherzherzherzherzherz', 
'##betrtghattkerlabend verpackungahmahm te te teila einfl einfl einflierende add adduff', 
'##reisreisviert fairrug ganze ganze ganze veh wz wz wz ihr x ihrverdverdverdverd',
]

During validation:

Input IDs to encoder:

tensor([[ 101, 1037, 3072,  ...,    0,    0,    0],
        [ 101, 3072, 1030,  ...,    0,    0,    0],
        [ 101, 2174, 1010,  ...,    0,    0,    0],
        [ 101, 5262, 1010,  ...,    0,    0,    0]])

Output of encoder (to demonstrate that there is a difference per input):

tensor([[[-0.2494, -0.2050, -0.2032,  ..., -1.0734,  0.1397,  0.4336],
         [-0.2473,  0.0091, -0.2359,  ..., -0.6884,  0.2158, -0.0761],
         [-0.5098, -0.1364,  0.7411,  ..., -1.0496, -0.0250, -0.2929],
         ...,
         [-0.1039, -0.2547,  0.2264,  ..., -0.2483, -0.2153,  0.0748],
         [ 0.2561, -0.3465,  0.5167,  ..., -0.2460, -0.1611,  0.0155],
         [-0.0767, -0.3239,  0.4679,  ..., -0.2552, -0.1551, -0.1501]],

        [[-0.3001,  0.0428, -0.3463,  ..., -0.6265,  0.3733,  0.3856],
         [-0.1463, -0.0212,  0.1447,  ..., -0.7843, -0.0542,  0.2394],
         [ 0.7481, -0.3762,  0.6301,  ...,  0.2269,  0.0267, -0.4466],
         ...,
         [ 0.3723, -0.2708,  0.2251,  ..., -0.0096, -0.0072, -0.2217],
         [ 0.4360, -0.1101,  0.3447,  ...,  0.0117, -0.0956, -0.1236],
         [ 0.3221, -0.1846,  0.3263,  ..., -0.0600, -0.0025, -0.1883]],

        [[-0.1365,  0.1746,  0.1038,  ..., -0.2151,  0.7875,  0.8574],
         [ 0.1072,  0.2133, -0.8644,  ...,  0.0739,  1.0464,  0.3385],
         [ 0.7204,  0.2680,  0.0991,  ..., -0.2964, -0.8238, -0.0604],
         ...,
         [ 0.2686, -0.0701,  0.8973,  ..., -0.0366, -0.2160,  0.0276],
         [ 0.2265, -0.2171,  0.4239,  ...,  0.0833, -0.0573,  0.0297],
         [ 0.0690, -0.2430,  0.4186,  ...,  0.0897, -0.0287,  0.0762]],

        [[ 0.0408,  0.2332, -0.0992,  ..., -0.2242,  0.6512,  0.4630],
         [ 0.3257,  0.1358, -0.3344,  ...,  0.0866,  1.0004, -0.0733],
         [ 0.6827,  0.3013,  0.0672,  ..., -0.2793, -0.8870, -0.0024],
         ...,
         [ 0.4291, -0.5344,  0.0134,  ...,  0.0439,  0.0617, -0.4433],
         [ 0.4847, -0.2888,  0.2942,  ...,  0.0153,  0.0121, -0.1231],
         [ 0.4725, -0.3132,  0.3458,  ..., -0.0207,  0.0517, -0.4281]]])

Validation reference labels:

[
'eine repub @ @ li @ @ kanische strategie, um der wieder @ @ wahl von obama entgegen @ @ zu @ @ treten', 
'die fuhrungs @ @ krafte der republi @ @ kaner rechtfertigen ihre politik mit der notwendigkeit, den wahl @ @ betrug zu bekampfen.', 
'allerdings halt das brenn @ @ an center letz @ @ teres fur einen my @ @ thos, indem es bekraftigt, dass der wahl @ @ betrug in den usa sel @ @ tener ist als die anzahl der vom bli @ @ tz @ @ schlag geto @ @ teten menschen.', 
'die rechtsan @ @ walte der republi @ @ kaner haben in 10 jahren in den usa ubrigens nur 300 falle von wahl @ @ betrug ver @ @ zeichnet.',
]

Validation predictions after .generate() and .batch_decode() (garbage, but the same per input):

[
'##schleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschle', 
'##schleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschle', 
'##schleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschle', 
'##schleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschleschle',
]

Expected behavior

I would expect the model to generate a different output per input, as during training time.

Thank you for your help!

Hopefully, it is something simple that I am missing.
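The expected behavior described above can be checked automatically instead of by reading the printed predictions. The following is a minimal sketch of such a check; the helper name `predictions_collapsed` is hypothetical and not part of Transformers or PyTorch Lightning, and it assumes the predictions have already been decoded to strings with `.batch_decode()`.

```python
def predictions_collapsed(pred_str):
    """Return True if a batch of decoded predictions contains more than one
    entry and every entry is the same string.

    `pred_str` is assumed to be the list returned by
    `tokenizer.batch_decode(...)`, one string per input.
    """
    return len(pred_str) > 1 and len(set(pred_str)) == 1


# Shortened stand-ins for the outputs reported in this issue.
training_like = ["##exe int int fid", "##schleschleually vno", "##reisreisviert fair"]
validation_like = ["##schleschleschle", "##schleschleschle", "##schleschleschle"]

print(predictions_collapsed(training_like))
print(predictions_collapsed(validation_like))
```

A call like this could be placed in `validation_step` right after `batch_decode`, so that a collapsed batch is flagged as soon as it occurs.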

patil-suraj commented 3 years ago

Hi @anicolson ,

We would love to help, but sadly when you post such a long script it will be very hard and time-consuming for us to take a look at. We're happy to assist if you could provide a short, precise, and complete code snippet that is based on Transformers Seq2SeqTrainer only. Here's our guide on how to request support.

Also, from what I can see, you are initializing the BERT encoder and the BERT decoder separately. You could instead instantiate them together with the EncoderDecoderModel class to get a seq2seq model. Here are two Colab notebooks that show how to train EncoderDecoder models using Seq2SeqTrainer. The notebooks show how to fine-tune for a summarization task, but they could easily be adapted for translation as well.
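As a rough illustration of the suggestion above, here is a minimal sketch that builds a seq2seq model with `EncoderDecoderModel` and calls `.generate()` on it. To keep it runnable without downloading checkpoints, it uses tiny, randomly initialized `BertConfig` objects; the sizes, the `tiny_config` helper, and the token ids are assumptions made for illustration only, so the generated ids are meaningless.

```python
import torch
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

torch.manual_seed(0)


def tiny_config():
    # Hypothetical small sizes so the sketch runs quickly on CPU.
    return BertConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        max_position_embeddings=64,
    )


# The decoder config is marked as a decoder with cross-attention by this helper.
config = EncoderDecoderConfig.from_encoder_decoder_configs(tiny_config(), tiny_config())
model = EncoderDecoderModel(config=config)
model.eval()

# Stand-in for a tokenized batch of four source sentences.
input_ids = torch.randint(5, 100, (4, 10))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    pred = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_start_token_id=1,  # assumed start token id for this toy vocabulary
        pad_token_id=0,
        max_length=8,
        do_sample=False,
    )

print(pred.shape)
```

With real checkpoints, the model would typically be created with `EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "dbmdz/bert-base-german-uncased")` instead of from random configs, and the start, end, and padding token ids would come from the decoder tokenizer.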

Leverage BERT for Encoder-Decoder Summarization on CNN/Dailymail

Leverage RoBERTa for Encoder-Decoder Summarization on BBC XSum

anicolson commented 3 years ago

Thanks for your reply,

I am attempting to create a shorter version that is not so time-consuming.

Certainly, the EncoderDecoder is an attractive option if one is using natural language, but I would like to highlight that using BertGenerationDecoder allows the user to provide any sequence for cross-attention, even those derived from encoders that operate on modalities other than natural language, which I think is powerful.
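The usage described above, and the hunch from the original report, can be probed in isolation. The sketch below passes two different sets of arbitrary hidden states to a `BertGenerationDecoder` and checks whether the logits change. It assumes a tiny, randomly initialized decoder (the sizes are illustrative, not those of any checkpoint), and the random tensors stand in for the output of any encoder.

```python
import torch
from transformers import BertGenerationConfig, BertGenerationDecoder

torch.manual_seed(0)

# Hypothetical small configuration so no checkpoint download is needed.
config = BertGenerationConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    max_position_embeddings=64,
    is_decoder=True,
    add_cross_attention=True,
)
decoder = BertGenerationDecoder(config)
decoder.eval()  # disable dropout so any difference comes from the inputs

batch_size, src_len = 2, 7
# Every sequence starts from the same [BOS] token, as in the issue.
decoder_input_ids = torch.full(
    (batch_size, 1), config.bos_token_id, dtype=torch.long
)

# Stand-ins for encoder outputs; they could come from any modality.
states_a = torch.randn(batch_size, src_len, config.hidden_size)
states_b = torch.randn(batch_size, src_len, config.hidden_size)

with torch.no_grad():
    logits_a = decoder(
        input_ids=decoder_input_ids, encoder_hidden_states=states_a
    ).logits
    logits_b = decoder(
        input_ids=decoder_input_ids, encoder_hidden_states=states_b
    ).logits

# True when the decoder output responds to the cross-attention input.
depends_on_encoder = not torch.allclose(logits_a, logits_b)
print(depends_on_encoder)
```

If a check of this kind passes on a plain forward pass while `.generate()` still returns the same sequence for every input, one thing worth verifying is whether `encoder_hidden_states` is actually forwarded to the decoder at each generation step. Note also that dropout is active in training mode, which on its own can make outputs differ between inputs.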

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

yangyijune commented 1 year ago

Hi, have you solved this problem? I am encountering exactly the same problem. Any clues?