jsrozner / t5_finetune

A simple example for finetuning a HuggingFace T5 model. Includes code for intermediate generation.

Choice of AdamW vs AdaFactor? #1

Open FL33TW00D opened 3 years ago

FL33TW00D commented 3 years ago

Hi there, I was wondering if you could share your reasoning behind using AdamW over AdaFactor?

Thanks for sharing the script. Regards, Chris

jsrozner commented 3 years ago

It was the huggingface default (same with the choice of epsilon at 1e-8, and learning rates in the 1e-4 to 1e-5 range). It could be worth trying other optimizers. Have you seen better results with AdaFactor?

FL33TW00D commented 3 years ago

> It was the huggingface default (same with the choice of epsilon at 1e-8, and learning rates in the 1e-4 to 1e-5 range). It could be worth trying other optimizers. Have you seen better results with AdaFactor?

I've personally been using AdaFactor, based on the recommendations from Google and the following thread, which you may have already seen: https://discuss.huggingface.co/t/t5-finetuning-tips/684/12
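
For reference, a minimal sketch of a commonly suggested Adafactor setup for T5 finetuning (constant LR, relative step sizing disabled); the exact values are an assumption on my part rather than this repo's code, and the "t5-base" checkpoint is a placeholder:

from transformers import Adafactor, T5ForConditionalGeneration

# Placeholder checkpoint; swap in whatever model is actually being finetuned.
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# Fixed LR, no relative step sizes, no parameter scaling, no warmup-init.
optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
)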

I plan to do a quantitative comparison of the optimizers soon.

Regards, Chris

jsrozner commented 3 years ago

I'd been meaning to read through that post and tune over the optimizer as well! I think the transformers finetune.py script defaults to Adam (and a lot of the notebooks also seem to use Adam).

What params have you been using with AdaFactor? And how did you settle on them?

With Adam I've settled on a 1e-4 or 3e-4 LR and a linear schedule with epsilon 1e-8, on a dataset of 90k train and 30k eval examples, but it begins to overfit quite quickly, after only about 13 epochs.
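
For concreteness, a rough sketch of that AdamW setup using the LR, epsilon, and linear schedule mentioned above; the checkpoint and step count are placeholders, not values from this repo:

from torch.optim import AdamW
from transformers import T5ForConditionalGeneration, get_linear_schedule_with_warmup

model = T5ForConditionalGeneration.from_pretrained("t5-base")  # placeholder checkpoint
num_train_steps = 10_000  # placeholder: roughly len(train_dataloader) * num_epochs

optimizer = AdamW(model.parameters(), lr=3e-4, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_train_steps,
)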

FL33TW00D commented 3 years ago

Hi @jsrozner, I am not sure if you're still using this repo, but I used it as the baseline for my own and have been making incremental improvements over time, so thank you!

One of the main ways I've found to speed up training is implementing a collate_fn with dynamic padding and uniform batch lengths. In case you're still experimenting with T5, I thought I'd attach the snippet:

from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch):
    """
    Take a list of samples from a Dataset and collate them into a batch.

    Returns:
        A dictionary of tensors
    """
    pad_token_id = 0  # T5's pad token id

    src_ids = pad_sequence([sample['source_ids'] for sample in batch], batch_first=True, padding_value=pad_token_id)
    src_mask = pad_sequence([sample['source_mask'] for sample in batch], batch_first=True, padding_value=pad_token_id)
    src_text = [sample['source_text'] for sample in batch]

    tgt_ids = pad_sequence([sample['target_ids'] for sample in batch], batch_first=True, padding_value=pad_token_id)
    tgt_ids[tgt_ids == pad_token_id] = -100  # mask padding positions out of the loss
    tgt_mask = pad_sequence([sample['target_mask'] for sample in batch], batch_first=True, padding_value=pad_token_id)
    tgt_text = [sample['target_text'] for sample in batch]

    return {
        'source_ids': src_ids,
        'target_ids': tgt_ids,
        'source_mask': src_mask,
        'target_mask': tgt_mask,
        'source_text': src_text,
        'target_text': tgt_text
    }
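
For anyone trying the snippet, a hypothetical way to wire it into a DataLoader; train_dataset is an assumed Dataset whose items are dicts with the keys used above:

from torch.utils.data import DataLoader

# Assumption: train_dataset yields dicts with 'source_ids', 'source_mask',
# 'source_text', 'target_ids', 'target_mask', 'target_text', where the
# ids/masks are 1-D tensors of per-example token ids.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_batch,
)
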
jsrozner commented 3 years ago

Cool, collate was on my feature list actually! And I'm glad you've found it useful!

I've also been making a lot of changes - I've made it considerably more modular, so that everything inherits from a common abstract trainer base class.

I'll probably push the changes to this repo once my research calms down a little. I can incorporate the collate function then.

jsrozner commented 3 years ago

Also, Hugging Face's transformers offers a batch_encode method that should take care of uniform padding and length.
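
Something along these lines, presumably; a tiny sketch of letting the tokenizer pad a list of strings to the longest sequence in the batch (the checkpoint and strings are just placeholders):

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")  # placeholder checkpoint
enc = tokenizer(["a short clue", "a somewhat longer clue to encode"],
                padding='longest', return_tensors='pt')
# enc['input_ids'] and enc['attention_mask'] are padded to the longest item in the list.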

FL33TW00D commented 3 years ago

Hi @jsrozner, The reason I did it this way is that I was following along with this blog post: https://wandb.ai/pommedeterresautee/speed_training/reports/Train-HuggingFace-Models-Twice-As-Fast--VmlldzoxMDgzOTI

It means we no longer need to pad to max length when we are batch encoding, and can strategically take batches of similar length samples in order to reduce the amount of padding needed. This really accelerated my training, as my mean sample length is 48 tokens but the max is 128.
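
For illustration, a rough sketch of that similar-length batching idea; this is a paraphrase rather than the blog post's code, and the function name and length_fn argument are made up:

def make_length_grouped_batches(samples, batch_size, length_fn):
    """Sort samples by token length, then slice off consecutive chunks so each
    batch contains similarly sized sequences and needs minimal padding."""
    ordered = sorted(samples, key=length_fn)
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]

The resulting list of batches can then be shuffled each epoch so the training order isn't strictly shortest-to-longest.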

Can we do this with batch_encode? Would be easier if so.

jsrozner commented 3 years ago

I wrote the following, using huggingface tokenizer to handle the batch encoding. It will pad to the max length in a batch.

This also substantially reduces the memory footprint from what I had before.

> It means we no longer need to pad to max length when we are batch encoding, and can strategically take batches of similar length samples in order to reduce the amount of padding needed.

Initially I read this to mean that you intentionally collate batches that have similar length sequences, but that probably isn't what you'd want to do if there's any correlation between length and your objective, since then your batches would not be grouped in a truly random way?

This implementation does not attempt to group similarly sized sequences into the same batch, so if there is a batch where the longest sequence is 100 tokens and all the others are 10, it will still pad all of them to 100. Huggingface offers a pad_to_second_longest, which I think can help avoid this problem.

For an even larger dataset, where the dataset itself won't fit easily into memory, we'd want to write an IterableDataset.
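
As an aside, a minimal sketch of that IterableDataset idea; it is purely hypothetical (the class name is made up) and assumes the same line-aligned .source/.target files:

from torch.utils.data import IterableDataset

class LazyClueDataset(IterableDataset):
    """Sketch: stream (source, target) pairs lazily instead of reading both files into memory."""

    def __init__(self, source_path: str, target_path: str):
        self.source_path = source_path
        self.target_path = target_path

    def __iter__(self):
        # Read the two files in lockstep; each yielded item is a (src, tgt) string pair.
        with open(self.source_path) as f_src, open(self.target_path) as f_tgt:
            for src, tgt in zip(f_src, f_tgt):
                yield src.strip(), tgt.strip()

The full batched implementation: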

from __future__ import annotations

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, List, Dict

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer

log = logging.getLogger(__name__)

@dataclass
class DataSetEntry:
    src: str
    tgt: str

@dataclass
class DataLoaderConfig:
    shuffle: bool = True
    batch_size: int = 64
    num_workers: int = 4

@dataclass
class DatasetConfig:
    tokenizer: PreTrainedTokenizer
    max_examples: int = -1  # -1 for no limit; if > 0, truncate to this many examples
    src_len: int = 100
    tgt_len: int = 20

class ClueDatasetBatched(Dataset):
    def __init__(self,
                 dataset_config: DatasetConfig,
                 data_dir: str,
                 type_path):
        valid_type_paths = ["test", "train", "val"]
        assert type_path in valid_type_paths, f"Type path must be one of {valid_type_paths}"

        self.example_path = Path(data_dir) / type_path
        self.max_examples = dataset_config.max_examples

        # populated in build
        self._len = None        # the total number of examples
        self.data_list: Optional[List[DataSetEntry]] = None

        self._build()  # fill inputs, targets, max_lens

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        return self.data_list[index]

    def _build(self):
        source_path = self.example_path.with_suffix(".source")
        target_path = self.example_path.with_suffix(".target")

        with open(source_path, 'r') as f_source, \
            open(target_path, 'r') as f_target:

            source_lines, target_lines = f_source.readlines(), f_target.readlines()

            # do length calcs
            source_ct, target_ct = len(source_lines), len(target_lines)
            assert source_ct == target_ct, f"Source and target lengths don't match: {source_ct} vs {target_ct}"
            if self.max_examples > 0:
                source_ct = min(self.max_examples, source_ct)
            self._len = source_ct

            self.data_list = []
            for idx in range(source_ct):
                src = source_lines[idx].strip()
                tgt = target_lines[idx].strip()
                self.data_list.append(DataSetEntry(src, tgt))

def collate_fn(tokenizer: PreTrainedTokenizer, batch_list: List[DataSetEntry]) -> Dict:
    src_text = [e.src for e in batch_list]
    tgt_text = [e.tgt for e in batch_list]

    tokenized_inputs = tokenizer(src_text, padding='longest', return_tensors='pt')
    tokenized_outputs = tokenizer(tgt_text, padding='longest', return_tensors='pt')

    source_ids = tokenized_inputs["input_ids"]
    target_ids = tokenized_outputs["input_ids"]
    src_mask = tokenized_inputs["attention_mask"]      # might need to squeeze
    target_mask = tokenized_outputs["attention_mask"]  # might need to squeeze

    # We cast these to torch.long in preprocess batch
    ret = {"source_ids": source_ids,
           "source_mask": src_mask,
           "target_ids": target_ids,
           "target_mask": target_mask,
           "source_text": src_text,
           "target_text": tgt_text}

    return ret

def get_dataloader_batched(tokenizer,
                           dataset_config: DatasetConfig,
                           dl_config: DataLoaderConfig,
                           data_dir,
                           type_path: Optional[str] = None) -> DataLoader:

    def curried_collate_fn(input_list) -> Dict:
        return collate_fn(tokenizer, input_list)

    data_set = ClueDatasetBatched(dataset_config,
                                  data_dir=data_dir,
                                  type_path=type_path)
    dataloader = DataLoader(data_set,
                            batch_size=dl_config.batch_size,
                            shuffle=dl_config.shuffle,
                            num_workers=dl_config.num_workers,
                            collate_fn=curried_collate_fn)
    log.info(f'Dataset {type_path} loaded with size: {len(data_set)}')
    return dataloader
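
A hypothetical usage sketch for the above; the "t5-base" checkpoint and "data" directory are placeholders rather than values from the repo:

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
dataset_config = DatasetConfig(tokenizer=tokenizer, max_examples=-1)
# num_workers=0 keeps the example simple; the nested collate closure would need
# to be picklable for multi-worker loading on spawn-based platforms.
dl_config = DataLoaderConfig(batch_size=32, num_workers=0)

# Expects data/train.source and data/train.target, one example per line.
train_loader = get_dataloader_batched(tokenizer,
                                      dataset_config,
                                      dl_config,
                                      data_dir="data",
                                      type_path="train")

batch = next(iter(train_loader))
print(batch["source_ids"].shape, batch["target_ids"].shape)
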
jsrozner commented 3 years ago

@FL33TW00D what'd you think about the new implementation?