huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Getting near constant training loss, T5 not learning anything? #13033

Closed: prikmm closed this issue 3 years ago

prikmm commented 3 years ago

Environment info

Who can help

@patil-suraj, @sgugger

Information

Model I am using (Bert, XLNet ...): T5

I am trying to fine-tune T5 on XSum using a TPU, but the training loss stays nearly constant and the validation loss does not change at all. It's as if the model is not learning anything. I tried t5-small, t5-base, t5-large (on Kaggle), google/t5-v1_1-small and google/t5-v1_1-base, and all of them give a near-constant training loss. I applied all the tips from the T5 Finetuning Tips thread, such as using Adafactor. @patil-suraj was able to train t5-large with max_input_length=512, max_output_length=64 and batch_size=8, but I was also able to train t5-large with max_input_length=1024, max_output_length=128 and batch_size=128 on Kaggle. I don't know why this is happening. Is it because some of the layers are frozen by default?

Training loss for t5-small: [plot]

Eval loss for t5-small: [plot]

The problem arises when using:

I have modified the script

The task I am working on is:

To reproduce

Colab Link

Code excerpts from the Colab notebook, for an overview:

Dataset Creation:

from datasets import load_dataset
from torch.utils.data import Dataset


class MyXSum(Dataset):

    def __init__(self, Config, tokenizer, split_type):  

        main_ds = load_dataset("xsum")
        self.model_name = Config.model_checkpoint
        self.dataset = main_ds[split_type]
        self.tokenizer = tokenizer

        if split_type in ("validation", "test"):
            self.required_columns = ["input_ids", "attention_mask", "labels"]
            num_samples = 20  # only a small subset is used for validation/test
        else:
            self.required_columns = ["input_ids", "attention_mask",
                                     #"decoder_input_ids",
                                     "decoder_attention_mask",
                                     "labels"
                                    ]
            num_samples = None

        if num_samples:
            self.dataset = self.dataset.select(list(range(0, num_samples)))

    def __len__(self):
        return len(self.dataset)

    def preprocess_function(self, examples):

        _inputs = ["summarize: " + examples["document"]]
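        # T5 builds its decoder inputs by shifting `labels` right and inserting its
        # decoder start token (the pad token), so this "<pad>" prefix is redundant.
        _target = ["<pad>" + examples["summary"]]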

        model_inputs = self.tokenizer(_inputs, max_length=512,
                                      truncation=True, padding="max_length",
                                      return_tensors="pt")

        # Setup the tokenizer for targets
        with self.tokenizer.as_target_tokenizer():
            labels = self.tokenizer(_target, max_length=64,
                                    truncation=True, padding="max_length",
                                    return_tensors="pt")

        model_inputs = {
            "input_ids": model_inputs["input_ids"].squeeze(),
            "attention_mask": model_inputs["attention_mask"].squeeze(),
            "decoder_input_ids": labels["input_ids"].squeeze(),
            "decoder_attention_mask": labels["attention_mask"].squeeze(),
            "labels": labels["input_ids"].squeeze(),
        }

        model_inputs = {k: model_inputs[k] for k in self.required_columns}

        return model_inputs

    def __getitem__(self, index):

        return self.preprocess_function(self.dataset[index])
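For context on the "<pad>" prefix and the decoder_input_ids column that is commented out for training: when only labels are passed, T5ForConditionalGeneration derives decoder_input_ids itself by shifting labels one position to the right, inserting the decoder start token (the pad token, id 0, for T5) and replacing any -100 with the pad id. The pure-Python sketch below mirrors that logic on plain lists. shift_right is a hypothetical helper for illustration, not the library's own function, and the token ids are arbitrary.

```python
def shift_right(labels, decoder_start_token_id=0, pad_token_id=0):
    """Mimic T5-style label shifting on lists of token ids (illustrative only)."""
    shifted = []
    for row in labels:
        # Prepend the start token and drop the last position so the length is kept.
        new_row = [decoder_start_token_id] + row[:-1]
        # -100 marks ignored label positions; the decoder needs a real token id there.
        new_row = [pad_token_id if tok == -100 else tok for tok in new_row]
        shifted.append(new_row)
    return shifted


# Arbitrary ids: two summary tokens, EOS (id 1 in T5), then two masked pad positions.
labels = [[363, 19, 1, -100, -100]]
print(shift_right(labels))  # [[0, 363, 19, 1, 0]]
```

If the labels already begin with the pad id, as the "<pad>" prefix produces, the decoder input simply starts with two pad tokens, so the prefix adds nothing.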

Model Training:

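from dataclasses import dataclass
from typing import Dict, List

import torch
from transformers import DataCollatorWithPadding


@dataclass
class T2TDataCollator(DataCollatorWithPadding):

    def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """
        Take a list of samples from a Dataset and collate them into a batch.
        Trainer calls the collator directly (data_collator(features)), so the
        logic must live in __call__; a method named collate_batch is never invoked.
        Returns:
            A dictionary of tensors
        """
        # torch.stack already yields [batch, seq_len]; no squeeze, which would
        # drop the batch dimension when the batch holds a single example.
        input_ids = torch.stack([example['input_ids'] for example in batch])
        attention_mask = torch.stack([example['attention_mask'] for example in batch])
        # The training split drops `decoder_input_ids`, so read `labels` instead,
        # and mask padding with -100 so the loss ignores those positions.
        labels = torch.stack([example['labels'] for example in batch]).clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        features = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }
        # Only the training split provides `decoder_attention_mask`.
        if 'decoder_attention_mask' in batch[0]:
            features['decoder_attention_mask'] = torch.stack(
                [example['decoder_attention_mask'] for example in batch])
        return features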

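import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from transformers import AutoModelForSeq2SeqLM, Trainer, TrainingArguments
from transformers.optimization import Adafactor, AdafactorSchedule

# `Config` and `tokenizer` are defined elsewhere in the notebook.
model = AutoModelForSeq2SeqLM.from_pretrained(Config.model_checkpoint)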
model.train()

WRAPPED_MODEL = xmp.MpModelWrapper(model)
optimizer = Adafactor(model.parameters(), scale_parameter=False, 
                      relative_step=False, warmup_init=False,
                      lr=1e-3)
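# NOTE: initial_lr is not passed, so AdafactorSchedule uses its default of 0.0
# and the learning rate stays at 0.0 (the cause identified later in this thread).
lr_scheduler = AdafactorSchedule(optimizer)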

data_collator = T2TDataCollator(tokenizer=tokenizer) 

train_ds = torch.load(Config.train_ds_path)
valid_ds = torch.load(Config.valid_ds_path)
test_ds = torch.load(Config.test_ds_path)

def _mp_fn(index):
    device = xm.xla_device()

    model = WRAPPED_MODEL.to(device)

    print("Loading datasets... ", end="")

    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=3,
        warmup_steps=0,
        evaluation_strategy="epoch",
        save_strategy="no",
        weight_decay=0.0,
        logging_dir="./log",
        #eval_steps=Config.eval_steps,
        logging_steps=50,
        per_device_train_batch_size=128,
        per_device_eval_batch_size=4,
    )

    #trainer = Seq2SeqTrainer(
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_ds,
        eval_dataset=valid_ds,
        optimizers=(optimizer, lr_scheduler),
    )
    trainer.place_model_on_device = False
    trainer.train()

xmp.spawn(_mp_fn, start_method="fork")
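The collator replaces pad ids with -100 because T5 computes its loss with PyTorch cross-entropy using ignore_index=-100: targets equal to -100 are skipped and the mean is taken over the remaining positions only. The pure-Python sketch below reproduces that reduction with made-up logits to show that a masked position contributes nothing. masked_cross_entropy is a hypothetical helper, not a library function; unlike PyTorch, it returns 0.0 when every position is masked.

```python
import math


def masked_cross_entropy(logits, labels, ignore_index=-100):
    """Mean cross-entropy over positions whose label is not ignore_index."""
    total, count = 0.0, 0
    for row, target in zip(logits, labels):
        if target == ignore_index:
            continue  # padded position: skipped entirely
        # -log softmax(row)[target], computed stably via the max trick
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]
        count += 1
    return total / count if count else 0.0


logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [9.0, -3.0, 4.0]]
with_pad = masked_cross_entropy(logits, [0, 1, -100])
without_pad = masked_cross_entropy(logits[:2], [0, 1])
print(with_pad == without_pad)  # True
```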

Expected behavior

Proper fine-tuning of T5.

CryptoSalamander commented 3 years ago

In my case, the TPU's BF16 datatype caused a fixed loss value. Did you use BF16 for training?
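One general way BF16 can freeze training: bfloat16 keeps only 8 significant bits (7 stored mantissa bits), so an update much smaller than a weight's magnitude can round away completely. The pure-Python sketch below emulates float32 to bfloat16 rounding (round-to-nearest-even on the 16 bits bfloat16 drops). to_bf16 is a hypothetical helper, and the update sizes are illustrative; this shows the precision issue in general, not a diagnosis of this particular run.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]  # float32 bit pattern
    # Round-to-nearest-even on the low 16 bits that bfloat16 discards.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack(">f", struct.pack(">I", bits))[0]


weight = to_bf16(1.0)
print(to_bf16(weight - 1e-3))  # 1.0: the small step is lost entirely
print(to_bf16(weight - 1e-2))  # 0.98828125: large enough to register
```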

prikmm commented 3 years ago

Hey @CryptoSalamander, thanks for your reply. I finally found the issue: my LR was 0.0. I was under the impression that AdafactorSchedule would take the lr set in the optimizer and update it at every step. But when using AdafactorSchedule, we have to pass in initial_lr, otherwise it defaults to 0.0. Since relative updates were disabled (relative_step=False, as recommended), the lr remained constant at 0.0.
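The symptom is easy to reproduce in miniature: with a learning rate of exactly 0.0, every parameter update is zero, so the loss never moves no matter which optimizer is used. The toy gradient-descent loop below is plain Python on a hypothetical one-weight regression (not T5 or Adafactor) and shows a flat loss at lr=0.0 versus a falling loss at a nonzero lr. In the notebook above, the corresponding fix would be to pass the intended value explicitly, e.g. AdafactorSchedule(optimizer, initial_lr=1e-3), so the scheduler does not fall back to its 0.0 default.

```python
def train_losses(lr, steps=5):
    """Fit y = 3x with a single weight by gradient descent; return the loss per step."""
    xs, ys = [1.0, 2.0, 3.0], [3.0, 6.0, 9.0]
    w = 0.0
    losses = []
    for _ in range(steps):
        # Mean squared error and its gradient with respect to w.
        loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)
        grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        losses.append(loss)
        w -= lr * grad  # with lr == 0.0 this update is always zero
    return losses


print(train_losses(lr=0.0))   # [42.0, 42.0, 42.0, 42.0, 42.0]
print(train_losses(lr=0.05))  # strictly decreasing, starting at 42.0
```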