foundation-model-stack / fms-acceleration

🚀 Collection of libraries used with fms-hf-tuning to accelerate fine-tuning and training of large models.
Apache License 2.0

Mixture of Experts Training with Acceleration Library Plugin #69

Closed · fabianlim closed this 1 week ago

fabianlim commented 3 months ago

This PR adds a plugin for mixture-of-experts training, combining FSDP with expert parallelism, where the latter is borrowed from Databricks MegaBlocks.

This implements the FSDP1 version of expert parallel from https://github.com/foundation-model-stack/moe-distributed-training

What is Expert Parallel?

Expert parallel is a form of model parallelism that applies to mixture-of-experts models: rather than every device holding (a shard of) all experts, each device owns only a subset of the experts, and tokens are routed to the devices that own their assigned experts (see the toy sketch after the diagram below).

Figure: Data Parallel (e.g., FSDP) vs. Expert Parallel.
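To make the distinction concrete, here is a toy sketch of the two layouts (the expert count, world size, and variable names are illustrative, not taken from the plugin):

# toy illustration: expert-to-rank assignment under the two schemes
num_experts, world_size = 8, 4

# data parallel (e.g. FSDP): every rank holds (a shard of) every expert
data_parallel_layout = {rank: list(range(num_experts)) for rank in range(world_size)}

# expert parallel: each rank owns a disjoint subset of the experts, and
# tokens are routed to the rank that owns their assigned expert
experts_per_rank = num_experts // world_size
expert_parallel_layout = {
    rank: list(range(rank * experts_per_rank, (rank + 1) * experts_per_rank))
    for rank in range(world_size)
}

print(expert_parallel_layout)  # {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7]}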

Performance

Benchmark Results

Full-Finetuning

| Model | GPUs | Type | mem_peak | mem_alloc | train_runtime | Mem improvement | Runtime improvement |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Mixtral-8x7B-Instruct-v0.1 | 8 | FSDP-only | 54.8 G | 44.0 G | 4019 s | baseline | baseline |
| Mixtral-8x7B-Instruct-v0.1 | 8 | our plugin | 45.5 G | 33.5 G | 996 s | 33 % | 4.00 x |

NOTE: the train runtimes were collected with --skip_memory_metrics=True (the Hugging Face default); setting this to False was used only to benchmark the memory numbers, as it is known to result in worse runtime measurements.

NOTE: throughput numbers were 83 and 337 tokens per second, respectively.
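For reference, skip_memory_metrics is the standard Hugging Face TrainingArguments flag; below is a minimal sketch of the two benchmarking configurations (output_dir and all omitted arguments are placeholders):

from transformers import TrainingArguments

# runtime benchmarks: keep the default, which skips memory tracking
runtime_args = TrainingArguments(output_dir="out", skip_memory_metrics=True)

# memory benchmarks: disable skipping to record mem_peak / mem_alloc,
# accepting the resulting slower and less representative runtimes
memory_args = TrainingArguments(output_dir="out", skip_memory_metrics=False)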

Checkpoint Resumption

Checkpointing works, as evidenced by correct training-resumption behavior (see the figure below):

Figure: training resumption from a saved checkpoint.

Next steps

Implementation Details

Comparison with DeepSpeed MoE (DS-MoE)

DeepSpeed also supports mixture-of-experts sharding; noting down some points here. In particular, DS-MoE requires the MoE parameters to be placed into separate optimizer parameter groups, e.g.:

def create_moe_param_groups(model):
    from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer

    # start from a single group containing all model parameters ...
    parameters = {'params': [p for p in model.parameters()], 'name': 'parameters'}

    # ... and let DeepSpeed split the MoE (expert) parameters out into
    # their own optimizer parameter groups
    return split_params_into_different_moe_groups_for_optimizer(parameters)

optimizer_grouped_parameters = create_moe_param_groups(opt_model)
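For context, a hedged sketch (not part of this PR) of how such groups would then be consumed; ds_config and the use of deepspeed.initialize here are assumptions for illustration:

import deepspeed

# hand the MoE-aware parameter groups to DeepSpeed in place of
# plain model.parameters()
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=opt_model,
    model_parameters=optimizer_grouped_parameters,
    config=ds_config,  # assumed: a DeepSpeed config (dict or path) with the MoE settings
)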

Updates to benchmark.py

We now also add an accelerated-moe-megablocks scenario to the benchmark scenarios; the <- NEW markers below highlight the changes: a baseline entry (without acceleration) in framework_config, and a slow: True flag so that the scenario is ignored in unfiltered runs:

  - name: accelerated-moe-megablocks
    framework_config: 
        - # without acceleration. <- NEW
        - moe-megablocks
    slow: True # <- NEW: will be ignored in unfiltered runs
    arguments:
        learning_rate: 5e-5
        torch_dtype: bfloat16
        accelerator_config: scripts/benchmarks/accelerator-config.json
        gradient_accumulation_steps: 16
        logging_steps: 1
        packing: False
        adam_epsilon: 1e-8
        model_name_or_path: 
            - 'mistralai/Mixtral-8x7B-Instruct-v0.1'

Checklist of items covered

Known Issues

The torch.concat operation dominates the load_sharded_experts_onto_device function; see the profile below, followed by an illustrative sketch of the pattern.

Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    3.171    3.171  588.308  588.308 fms-acceleration/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py:204(shard_moe)
       32    0.219    0.007  584.905   18.278 fms-acceleration/plugins/accelerated-moe/src/fms_acceleration_moe/megablocks_utils/shard_moe_utils.py:155(load_sharded_experts_onto_device)
      128  567.666    4.435  567.666    4.435 {built-in method torch.concat}
       96    0.013    0.000   14.993    0.156 /workspace/mb/lib/python3.10/site-packages/torch/distributed/_tensor/api.py:507(distribute_tensor)
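For illustration only, not the actual shard_moe_utils.py implementation: the profile is consistent with a pattern where per-expert weights are first concatenated on the host into one large tensor and only then distributed as a DTensor, so the torch.concat cost scales with the full expert weight size. A minimal sketch under that assumption (the function name and shapes are hypothetical):

import torch
from torch.distributed._tensor import Shard, distribute_tensor

def load_sharded_experts_sketch(expert_weights, device_mesh):
    # expert_weights: hypothetical list of per-expert weight tensors read
    # from the sharded checkpoint (one entry per expert)

    # concatenate all experts into a single tensor on the host; this is the
    # torch.concat call dominating the profile above
    stacked = torch.concat([w.unsqueeze(0) for w in expert_weights], dim=0)

    # distribute the combined tensor over the expert-parallel device mesh,
    # sharding along the expert dimension (dim 0)
    return distribute_tensor(stacked, device_mesh, placements=[Shard(0)])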
fabianlim commented 1 week ago

Superseded by #99