hpcaitech / ColossalAI

Making large AI models cheaper, faster and more accessible
https://www.colossalai.org
Apache License 2.0

[PROPOSAL]: refactor core API of Engine #2975

Closed: ver217 closed this issue 1 year ago

ver217 commented 1 year ago

Proposal

Motivation

  1. The current initialization process is complicated and hard to maintain. It contains hundreds of hard-coded if-else branches, which are hard to read and modify.
  2. The current Engine is hard to use. Its usage is very different from native PyTorch, so users have to spend some effort learning it before starting their first applications.
  3. The current Engine is not flexible. It relies on a configuration file or dict and a global context, so running two models with different parallelism methods is hard to implement today. It also supports only single-model training, which rules out well-known RL algorithms such as PPO.
  4. There is too much legacy code. Gemini and auto-parallelism each have their own entry points instead of going through Engine.

Design

We keep Engine as the main entry point of ColossalAI training.

(architecture diagram)

Engine has 6 main components:

Engine's features include:

Engine is not a singleton, though in most cases a single engine is enough.

Possible sample code (pseudo-code)

# create engine
plugin = GeminiPlugin(stage=3, ...)
engine = Engine(precision='fp16', parallelism_plugin=plugin)

# initialize models, optimizers, lr schedulers
model, optimizer, lr_scheduler = engine.initialize(model, optimizer, lr_scheduler)
# or multi-models
actor, critic, actor_optimizer, critic_optimizer = engine.initialize(actor, critic, actor_optimizer, critic_optimizer)

# forward and backward
outputs = model(inputs)
loss = criterion(outputs, labels)
engine.backward(loss, optimizer)
optimizer.step()

# run pipeline (another paradigm)
engine.execute_pipeline(data_iter, model, criterion, optimizer, ...)
optimizer.step()

# HF models generation
sequences = model.generate(input_ids)

# I/O supports 2 styles:
# 1. torch style (target path is a file)
# 2. Huggingface style (target path is a directory)
# torch style (doesn't consider checkpoint size; may OOM for large models)
engine.load(model, 'model.pt', plan='torch')
engine.save(optimizer, 'optimizer.pt', plan='torch')

# huggingface style (save checkpoint in chunks)
engine.save(model, 'checkpoint/gpt2', max_file_size_gb=10, plan='huggingface')
engine.load(optimizer, 'checkpoint/gpt2', plan='huggingface')
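
For intuition, here is a sketch of how the two plans could be implemented inside CheckpointIO. The helper name and the shard/index file layout below are assumptions borrowed from the Hugging Face convention, not the final API.

import json
import os
import torch

# Sketch (assumed helper, not the final CheckpointIO API) of the two save plans:
# 'torch' writes the whole state dict to one file, 'huggingface' splits it into
# shards no larger than max_file_size_gb and writes an index mapping.
def save_model_checkpoint(model, path, plan='torch', max_file_size_gb=10):
    state_dict = model.state_dict()
    if plan == 'torch':
        torch.save(state_dict, path)  # may OOM when gathering huge models
        return
    os.makedirs(path, exist_ok=True)
    budget = max_file_size_gb * 1024 ** 3
    shards, current, current_size = [], {}, 0
    for name, tensor in state_dict.items():
        size = tensor.numel() * tensor.element_size()
        if current and current_size + size > budget:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = tensor
        current_size += size
    if current:
        shards.append(current)
    weight_map = {}
    for i, shard in enumerate(shards):
        shard_file = f'pytorch_model-{i + 1:05d}-of-{len(shards):05d}.bin'
        torch.save(shard, os.path.join(path, shard_file))
        weight_map.update({name: shard_file for name in shard})
    with open(os.path.join(path, 'pytorch_model.bin.index.json'), 'w') as f:
        json.dump({'weight_map': weight_map}, f, indent=2)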

Single-model supervised learning train loop without pipeline

colossalai.launch(...)
plugin = GeminiPlugin(stage=3, ...)
engine = Engine(precision='fp16', parallelism_plugin=plugin)

model = GPT2()
optimizer = Adam(model.parameters())
dataloader = DataLoader(dataset)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()

model, optimizer, lr_scheduler, dataloader = engine.initialize(model, optimizer, lr_scheduler, dataloader)

for epoch in range(max_epochs):
    for input_ids, attention_mask in dataloader:
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs.logits, input_ids)
        engine.backward(loss, optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
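
The class definition below also proposes engine.no_sync. Here is a sketch of how it could be combined with gradient accumulation in the loop above, assuming the plugin reports support_no_sync; accumulation_steps is a hypothetical setting introduced for illustration.

from contextlib import nullcontext

# Sketch: skip gradient synchronization on non-boundary micro-steps and only
# step the optimizer every accumulation_steps batches.
for step, (input_ids, attention_mask) in enumerate(dataloader):
    sync = (step + 1) % accumulation_steps == 0
    with (nullcontext() if sync else engine.no_sync(model)):
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs.logits, input_ids) / accumulation_steps
        engine.backward(loss, optimizer)
    if sync:
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()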

Single-model supervised learning train loop with pipeline

colossalai.launch(...)
plugin = GeminiPlugin(stage=3, ...)
engine = Engine(precision='fp16', parallelism_plugin=plugin)

model = GPT2()
optimizer = Adam(model.parameters())
dataloader = DataLoader(dataset)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()

model, optimizer, lr_scheduler, dataloader = engine.initialize(model, optimizer, lr_scheduler, dataloader)

for epoch in range(max_epochs):
    num_steps = len(dataloader)
    for step in range(num_steps):
        loss = engine.execute_pipeline(dataloader, model, criterion, optimizer, return_loss=True, return_outputs=False)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
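
For intuition, a highly simplified sketch of what execute_pipeline wraps on a single stage is shown below. The micro-batch split is illustrative only; the real implementation would also handle the pipeline schedule (e.g. 1F1B) and inter-stage communication, which are omitted here.

# Simplified sketch (assumption) of what execute_pipeline conceptually does on
# one stage: pull a batch, split it into micro-batches, run forward/backward
# per micro-batch, and return the averaged loss. The optimizer argument is
# unused in this simplification and kept only for signature parity.
def execute_pipeline_sketch(data_iter, model, criterion, optimizer,
                            num_microbatches=4, return_loss=True):
    input_ids, attention_mask = next(data_iter)
    total_loss = 0.0
    for ids, mask in zip(input_ids.chunk(num_microbatches),
                         attention_mask.chunk(num_microbatches)):
        outputs = model(ids, mask)
        loss = criterion(outputs.logits, ids) / num_microbatches
        loss.backward()
        total_loss += loss.item()
    return total_loss if return_loss else None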

Multi-model RL train loop without pipeline

colossalai.launch(...)
plugin = GeminiPlugin(stage=3, ...)
engine = Engine(precision='fp16', parallelism_plugin=plugin)

actor = GPT2Actor()
critic = GPT2Critic()
actor_optim = Adam(actor.parameters())
critic_optim = Adam(critic.parameters())
actor_loss_fn = ActorLoss()
critic_loss_fn = CriticLoss()

actor, critic, actor_optim, critic_optim = engine.initialize(actor, critic, actor_optim, critic_optim)

for epoch in range(max_epochs):
    for experience in replay_buffer:
        action_log_probs = actor(experience.sequences)
        actor_loss = actor_loss_fn(action_log_probs, experience.old_log_probs, experience.adv)
        engine.backward(actor_loss, actor_optim)
        actor_optim.step()
        actor_optim.zero_grad()

        values = critic(experience.sequences)
        critic_loss = critic_loss_fn(values, experience.old_values, experience.reward)
        engine.backward(critic_loss, critic_optim)
        critic_optim.step()
        critic_optim.zero_grad()
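
Since Engine is not a singleton, the actor and critic above could also be set up by two separate engines with different parallelism plugins, which the current global-context design cannot express. A sketch (TorchDDPPlugin is a hypothetical DDP-based plugin, sketched after the class definitions below):

gemini_engine = Engine(precision='fp16', parallelism_plugin=GeminiPlugin(stage=3))
ddp_engine = Engine(precision='fp32', parallelism_plugin=TorchDDPPlugin())

actor, actor_optim = gemini_engine.initialize(actor, actor_optim)
critic, critic_optim = ddp_engine.initialize(critic, critic_optim)

In the training loop, engine.backward(actor_loss, actor_optim) and engine.backward(critic_loss, critic_optim) would then go through gemini_engine and ddp_engine respectively.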

Possible class definition (pseudo-code)

class Engine:
    def __init__(self, 
                 device: Union[str, torch.device] = 'cuda',
                 precision: str = 'fp32',
                 grad_clipping_type: str = 'norm',
                 grad_clipping_value: float = 0.0,
                 parallelism_plugin: Optional[ParallelismPlugin] = None) -> None:
        # sanity check
        assert device in parallelism_plugin.supported_devices
        assert precision in parallelism_plugin.supported_precisions

        self.parallelism_plugin = parallelism_plugin
        self.accelerator = None
        self.precision_bolt = None
        if not parallelism_plugin.control_device:
            self.accelerator = Accelerator(device)
        if not parallelism_plugin.control_precision:
            self.precision_bolt = PrecisionBolt(precision, grad_clipping_type, grad_clipping_value)
        self.environment_table = EnvironmentTable(parallelism_plugin.device_mesh_shapes)
        self.checkpoint_io = CheckpointIO(self.parallelism_plugin, self.precision_bolt, self.accelerator)

    def initialize(self, *args: Union[Module, Optimizer, LRScheduler, DataLoader]) -> List[Union[Module, Optimizer, LRScheduler, DataLoader]]:
        rets = []
        for arg in args:
            if isinstance(arg, Module):
                arg = self.parallelism_plugin.setup_model(arg, self.environment_table.device_mesh_pool)
                if not self.parallelism_plugin.control_precision:
                    arg = self.precision_bolt.setup_model(arg)
                if not self.parallelism_plugin.control_device:
                    arg = self.accelerator.setup_model(arg)
            elif isinstance(arg, Optimizer):
                arg = self.parallelism_plugin.setup_optimizer(arg)
                if not self.parallelism_plugin.control_precision:
                    arg = self.precision_bolt.setup_optimizer(arg)
            else:
                # TODO
                pass
            rets.append(arg)
        return rets

    def backward(self, loss: Tensor, optimizer: Optimizer) -> None:
        # do backward when not using pipeline
        if not self.parallelism_plugin.control_precision:
            loss = self.precision_bolt.scale_loss(loss)
        optimizer.backward(loss)

    def execute_pipeline(self, data_iter: Iterator, model: Module, criterion: Callable[[Inputs, Outputs], Tensor], optimizer: Optimizer, return_loss: bool = True, return_outputs: bool = False) -> Tuple[Optional[Tensor], ...]:
        # run pipeline forward backward pass
        # return loss or outputs if needed
        pass

    def no_sync(self, model: Module) -> Context:
        if not self.parallelism_plugin.support_no_sync:
            raise RuntimeError()
        return model.no_sync()

    def save(self, obj: Union[Module, Optimizer, LRScheduler], path_like: str, plan: str = 'torch', **kwargs) -> None:
        pass

    def load(self, obj: Union[Module, Optimizer, LRScheduler], path_like: str, plan: str = 'torch', **kwargs) -> None:
        pass

class EnvironmentTable:
    def __init__(self, intra_op_world_sizes: List[int]):
        self.world_size
        self.rank
        self.global_group
        self.device_mesh_pool # generate from intra_op_world_sizes

    @property
    def is_master(self) -> bool:
        pass

class Accelerator:
    def __init__(self, device):
        self.device = device

    def setup_model(self, model) -> nn.Module:
        pass

class PrecisionBolt:
    def __init__(self, precision_type: str, grad_clipping_type: str, grad_clipping_value: float):
        self.precision_type = precision_type
        self.grad_clipping_type = grad_clipping_type
        self.grad_clipping_value = grad_clipping_value

    def setup_model(self, model) -> nn.Module:
        pass

    def setup_optimizer(self, optimizer) -> Optimizer:
        # inject grad clipping and unscale loss
        pass

    def scale_loss(self, loss) -> torch.Tensor:
        pass

class ParallelismPlugin:
    @property
    def supported_devices(self) -> List[device]:
        pass

    @property
    def supported_precisions(self) -> List[str]:
        pass

    @property
    def control_precision(self) -> bool:
        pass

    @property
    def control_device(self) -> bool:
        pass

    @property
    def support_no_sync(self) -> bool:
        pass

    def setup_model(self, model, device_mesh_pool) -> Module:
        pass

    def setup_optimizer(self, optimizer) -> Optimizer:
        pass

    def setup_dataloader(self, dataloader) -> DataLoader:
        pass

    @property
    def device_mesh_shapes(self) -> List[Tuple[int, ...]]:
        pass
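
To make the plugin interface above concrete, here is an illustrative sketch of a minimal plugin built on plain PyTorch DDP. TorchDDPPlugin and all of its choices (placing the model on GPU itself, delegating precision to PrecisionBolt, a trivial one-dimensional device mesh) are assumptions for illustration, not part of the proposal.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

class TorchDDPPlugin(ParallelismPlugin):
    """Illustrative plugin: wrap the model in DDP, delegate precision to PrecisionBolt."""

    @property
    def supported_devices(self):
        return ['cuda']

    @property
    def supported_precisions(self):
        return ['fp32', 'fp16']

    @property
    def control_precision(self) -> bool:
        return False  # PrecisionBolt handles fp16 and grad clipping

    @property
    def control_device(self) -> bool:
        return True   # this plugin places the model on GPU itself

    @property
    def support_no_sync(self) -> bool:
        return True   # DDP models expose no_sync()

    def setup_model(self, model, device_mesh_pool):
        return DDP(model.cuda())

    def setup_optimizer(self, optimizer):
        return optimizer  # plain DDP needs no optimizer wrapping

    def setup_dataloader(self, dataloader):
        return dataloader  # a DistributedSampler could be attached here

    @property
    def device_mesh_shapes(self):
        return [(torch.distributed.get_world_size(),)]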

Further work

Huggingface/accelerate and Lightning/fabric may have similar designs.

We may provide a ColossalAI plugin / strategy for these libraries.


FrankLeeeee commented 1 year ago

@ver217 There are some suggestions regarding the API design:

  1. Rename PrecisionBolt, as "bolt" is rather unclear.
  2. Don't name the plugin ParallelismPlugin, as we can extend it to other features such as quantization. I think it is enough to simply name it Plugin.
  3. Don't name it Engine, as that can be a bit misleading, as discussed earlier on.

FrankLeeeee commented 1 year ago

This issue has been migrated to #3046, so I will close it for now; all discussions will take place in #3046.