xrsrke / pipegoose

Large-scale 4D parallelism pre-training for 🤗 transformers in Mixture of Experts *(still a work in progress)*

Trainer #18

Open xrsrke opened 10 months ago

xrsrke commented 10 months ago

Notes

APIs

Trainer

from pipegoose.trainer import Trainer, TrainingArguments

config = {
    "tensor_parallelism": {"parallel_size": 2},
    "pipeline_parallelism": {
        "parallel_size": 4,
        "params": {"num_microbatches": 5}
    },
    "data_parallelism": {
        "parallel_size": 2,
        "zero_1": True
    },
    "mixed_precision": {"fp16": True}, # or bf16
    "fusion": {
        "optim": True,
        "model": True
    }
}

args = TrainingArguments(
    optim="adam",
    learning_rate=1e-3,
    lr_scheduler="",
    num_train_epochs=100,
    num_eval_steps=50,
    seed=42,
    config=config
)

trainer = Trainer(
    model=model, # loaded from `transformers`
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    args=args,
    callbacks=[PrintResultCallback(), SaveCheckpointCallback()]
)

trainer.train()
trainer.eval()
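
Side note on the config above: tensor_parallel_size=2 × pipeline_parallel_size=4 × data_parallel_size=2 implies 16 processes in total. A hedged launch sketch, assuming a standard torchrun entry point and a hypothetical train.py script (the launch method isn't specified in this issue):

# 2 (tensor) x 4 (pipeline) x 2 (data) = 16 processes on one node
torchrun --standalone --nproc_per_node=16 train.py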

Trainer Callback

from pipegoose.trainer import Callback

class LoggingCallback(Callback):
    def on_train_start(
        self, trainer, model, optim,
        train_dataloader, eval_dataloader
    ):
        print("Training is starting")

    def on_train_end(
        self, trainer, model, optim,
        train_dataloader, eval_dataloader
    ):
        print("Training is ending")

DistributedDataLoader

from torch.utils.data import DataLoader
from pipegoose.utils.data import DistributedDataLoader

dataloader = DataLoader(dataset, batch_size=1024, shuffle=False)
dataloader = DistributedDataLoader(dataloader, parallel_context)
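
The parallel_context above isn't constructed in this snippet; presumably it is a ParallelContext whose sizes match the Trainer config. A sketch, assuming the ParallelContext.from_torch constructor and import path used in the repo's other examples:

from pipegoose.distributed import ParallelContext

# sizes mirror the Trainer config above: tensor=2, pipeline=4, data=2
parallel_context = ParallelContext.from_torch(
    tensor_parallel_size=2,
    pipeline_parallel_size=4,
    data_parallel_size=2,
)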

TODOs

isamu-isozaki commented 10 months ago

I think I'll do this tonight since it seems the easiest

xrsrke commented 10 months ago

@isamu-isozaki Awesome, thank you! I will get back to you in a few hours with all the details!!

isamu-isozaki commented 10 months ago

@xrsrke I was thinking of maybe just inheriting from transformers' Trainer. wdyt?

xrsrke commented 10 months ago

@isamu-isozaki Nope, I just checked transformers' Trainer: it modifies the model's devices and other state under the hood. We'd prefer to implement our own so we can incorporate distributed logging, callbacks that run on a specific rank or ParallelMode, and future changes. I just added some demo code (link).

Also one note: we only apply a given parallel mode if the parallel_context calls for it. For example, if data_parallel_size is greater than 1, we wrap the model with DataParallel.
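
A minimal sketch of that dispatch logic, assuming the DataParallel / TensorParallel wrappers take (module, parallel_context) and expose .parallelize(), and that ParallelContext reports per-mode world sizes via get_world_size(ParallelMode.*) (the exact names here are assumptions, not the final API):

from pipegoose.distributed import ParallelContext, ParallelMode
from pipegoose.nn import DataParallel, TensorParallel

def parallelize(model, parallel_context: ParallelContext):
    # only apply a parallel mode when its world size in the parallel_context is > 1
    if parallel_context.get_world_size(ParallelMode.TENSOR) > 1:
        model = TensorParallel(model, parallel_context).parallelize()
    if parallel_context.get_world_size(ParallelMode.DATA) > 1:
        model = DataParallel(model, parallel_context).parallelize()
    return model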