volcengine / veScale

A PyTorch Native LLM Training Framework
http://vescale.xyz
Apache License 2.0

[RFC] veScale: High-Level API for nD Parallel Training #39

Open · leonardo0lyj opened 2 months ago

leonardo0lyj commented 2 months ago

TL;DR

[Figure: TL;DR overview]

Motivation

Our current APIs for nD parallel training are low-level and complex for common users. Ideally, we want a simpler, high-level API like this:

Single Device Code

dataset = ...
data_loader = torch.utils.data.DataLoader(dataset, ...)

class Net(nn.Module):
    ...

def optimizer_fn(model):
    ...
    return torch.optim.Adam(model_param_groups, ...)

def lr_scheduler_fn(optimizer):
    ...
    return torch.optim.lr_scheduler.StepLR(optimizer, ...)

model = Net(...)
optimizer = optimizer_fn(model)
scheduler = lr_scheduler_fn(optimizer)

for epoch in range(10):
    for batch in data_loader:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
    scheduler.step()

# save model, optimizer, and scheduler states in one checkpoint file
torch.save(
    { "model" : model.state_dict(),
      "optimizer" : optimizer.state_dict(),
      "lr_scheduler" : scheduler.state_dict() },
    "/path/to/checkpoint",
)
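
Restoring is the mirror image of the combined save above (the path is a placeholder):

# load the combined checkpoint and restore each component's state
checkpoint = torch.load("/path/to/checkpoint")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["lr_scheduler"])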

veScale High-Level API for nD Parallel Training

dataset = ...

### zero code change on model
class Net(nn.Module):
    ...

def optimizer_fn(model):
    ...
    return torch.optim.Adam(model_param_groups, ...)

def lr_scheduler_fn(optimizer):
    ...
    return torch.optim.lr_scheduler.StepLR(optimizer, ...)

### create giant model without OOM
model = vescale.deferred_init(Net, ...)
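
For intuition only: stock PyTorch's meta device illustrates the general trick behind deferred initialization, i.e. describing a model without allocating real storage and materializing it later. This is a hedged sketch of the technique, not veScale's implementation of deferred_init:

import torch
import torch.nn as nn

# Parameters built under the meta device carry shapes and dtypes but no
# storage, so arbitrarily large models can be described without OOM.
with torch.device("meta"):
    net = nn.Linear(4096, 4096)
assert net.weight.is_meta

# Materialization allocates real (uninitialized) storage; an nD
# parallelizer would instead materialize only each rank's shard.
net = net.to_empty(device="cpu")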

### generate plan of nD parallel training under user constraints
constraints = { "pipeline_parallel.split_method" : "flops",
                "tensor_parallel.sharding_policy" : "megatron" }
plan = vescale.generate_plan(constraints, model)
# print(plan)
#   pipeline_parallel.split_points : ["layer1", "layer3", ...]
#   tensor_parallel.sharding_plan : { "layer2.weight" : [Shard(dim=0)], ... }
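
The Shard(dim=0) placements in the plan follow DTensor semantics. As a refresher on what row-wise sharding means, here is a minimal stock-PyTorch DTensor sketch (PyTorch 2.4+; earlier versions import from torch.distributed._tensor). It is illustrative, not veScale code; run it with torchrun --nproc_per_node=4:

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

# Shard(0) splits rows of the global (16, 16) tensor across the mesh:
# with 4 ranks, each rank holds a (4, 16) local shard.
mesh = init_device_mesh("cpu", (4,))
weight = torch.randn(16, 16)
dweight = distribute_tensor(weight, mesh, placements=[Shard(0)])
print(dweight.to_local().shape)  # torch.Size([4, 16])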

### create nD parallel model and optimizer, specified by the plan
model, optimizer, scheduler, data_loader = vescale.parallelize(
    plan, model, optimizer_fn, lr_scheduler_fn, dataset
)
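
For comparison, upstream PyTorch already applies tensor-parallel plans to modules in a similar one-call style. A hedged sketch using the stock torch.distributed.tensor.parallel API (not the proposed vescale.parallelize); run under torchrun with 4 ranks:

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel

# A plan maps submodule names to sharding styles; parallelize_module
# rewrites the matched submodules' parameters into DTensors accordingly.
mesh = init_device_mesh("cpu", (4,))
mlp = nn.Sequential(nn.Linear(1024, 4096))
tp_mlp = parallelize_module(mlp, mesh, {"0": ColwiseParallel()})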

### zero code change on training loop
for epoch in range(10):
    for batch in data_loader:
        optimizer.zero_grad()
        ### trains nD parallel model as if on single device
        loss = model(batch)
        loss.backward()
        optimizer.step()
    scheduler.step()

### saves nD parallel model and optimizer
vescale.save("/path/to/checkpoint", { "plan": plan, "model" : model, "optimizer" : optimizer, "lr_scheduler": scheduler })
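
On the checkpoint side, recent PyTorch (2.2+) ships a comparable one-call distributed checkpoint API, which gives a sense of what a parallel-aware save entails: each rank writes only its own shards, and the result can be resharded on load. A hedged sketch with torch.distributed.checkpoint, not the proposed vescale.save:

import torch.distributed.checkpoint as dcp

# Each rank writes its local shards; the resulting checkpoint can be
# loaded under a different world size or sharding layout.
state = { "model": model.state_dict(), "optimizer": optimizer.state_dict() }
dcp.save(state, checkpoint_id="/path/to/checkpoint")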

Idea

[Figures: proposed design illustrations]

Feedback is all we need :)

(image source: 1 2)