foundation-model-stack / fms-fsdp

🚀 Efficiently (pre)training foundation models with native PyTorch features, including FSDP for training and SDPA implementation of Flash attention v2.
https://pytorch.org/docs/stable/fsdp.html
Apache License 2.0

maximize mistral throughput #63

Open · aldopareja opened this issue 3 months ago

aldopareja commented 3 months ago

The InstructLab backend currently focuses on Mistral fine-tuning, and I'm trying to maximize throughput for that. If anyone notices anything obvious or has any suggestions, I'd truly appreciate it. @raghukiran1224 mentioned that posting an issue here might help.

I'm currently seeing a throughput of around 90 samples per second at a max context length of 2600 tokens (though the average sample is only around 500 tokens) on 80 GPUs in prod Vela. On a single node I get a throughput of around 11.2 samples per second, and the best configuration there is SHARD_GRAD_OP (ZeRO stage 2) with no gradient checkpointing.

The main bottleneck is networking, so the largest possible batch size maximizes throughput: the communication cost is nearly the same regardless of the batch size. For that reason I ended up using HYBRID_SHARD_ZERO2 and enabling gradient checkpointing to fit a batch size of 20 samples per GPU at 2600 max length.
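
As a rough back-of-envelope check on those figures (assuming one micro-batch per optimizer step, which may not match the real run exactly), the effective batch size and step time work out as follows:

# Rough arithmetic from the numbers above; the step-time figure is an
# estimate derived from them, not a measured value.
gpus = 80
per_gpu_bs = 20      # samples per GPU at max length 2600
throughput = 90      # samples per second across the whole cluster

global_batch = gpus * per_gpu_bs              # 1600 samples per step
seconds_per_step = global_batch / throughput  # ~17.8 s per step
print(global_batch, round(seconds_per_step, 1))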

These are the main parts to look at:

Model setup

Currently using HYBRID_SHARD_ZERO2, but I have experimented with all the possibilities. I couldn't get torch.compile to work, and I had to enable gradient checkpointing to maximize the batch size.

import math
from functools import partial

import torch
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer


def setup_model(model_name, tokenizer):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )
    if len(tokenizer) > model.config.vocab_size:
        print(
            f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
        )
        model.resize_token_embeddings(
            int(8 * math.ceil(len(tokenizer) / 8.0))
        )  # make the vocab size multiple of 8 for sharding the embedding layer.

    assert model.__class__.__name__ in [
        "MistralForCausalLM"
    ], f"Model class name: {model.__class__.__name__} is not supported."

    model = FSDP(
        model,
        auto_wrap_policy=partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                MistralDecoderLayer,
            },
        ),
        # use_orig_params=True,
        limit_all_gathers=True,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        sharding_strategy=ShardingStrategy._HYBRID_SHARD_ZERO2,
        device_id=torch.cuda.current_device(),
    )
    model.gradient_checkpointing_enable()
    # model = torch.compile(model)
    return model
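
For reference, a minimal way to drive setup_model in a distributed job might look like the sketch below. The torchrun launch, environment variables, and model name are assumptions on my side, not taken from the InstructLab code.

# Minimal sketch; assumed launch: torchrun --nproc_per_node=8 train.py
# "mistralai/Mistral-7B-v0.1" is a placeholder model name.
import os
import torch
import torch.distributed as dist
from transformers import AutoTokenizer

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group("nccl")

model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = setup_model(model_name, tokenizer)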

Training loop

Importantly, note use_cache: even though the explicit use_cache=False is commented out, the cache still ends up disabled here only because gradient checkpointing is enabled (Transformers forces use_cache=False when checkpointing is on); without checkpointing it defaults to True.

        for batch in train_loader:
            start = time.time()
            # move the batch tensors (input_ids, attention_mask, labels) to this rank's GPU
            for k in batch:
                batch[k] = batch[k].to(local_rank)

            output = model(
                **batch,
                # use_cache=False,  # effectively False anyway while gradient checkpointing is on
            )

            loss = output["loss"]
            loss.backward()

            # step the optimizer only every gradient_accumulation_steps micro-batches
            if global_step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
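
Two small, optional tweaks on top of that loop (my additions, not part of the original code): make the cache setting explicit instead of relying on the gradient-checkpointing side effect, and use the otherwise unused start timestamp to report per-step throughput.

# Assumed additions, not in the original loop:
# make the cache setting explicit once, before the training loop starts
model.config.use_cache = False

# and, inside the loop, turn the `start` timestamp into a throughput number
step_time = time.time() - start
samples_per_sec = batch["input_ids"].shape[0] / step_time
if local_rank == 0:
    print(f"{samples_per_sec:.1f} samples/s on rank 0")
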
aldopareja commented 3 months ago

Throughput test

I performed many experiments in the prod cluster trying to get the best throughput using PyTorch's FSDP. In general, once the number of nodes grows above a certain threshold, the best throughput comes from the HYBRID_SHARD approach, which does FULL_SHARD within each node while every node keeps a full copy of the model, so most of the all_gather operations happen inside a node, whose networking is much faster than inter-node communication.
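
For context, switching between the hybrid strategies is just the sharding_strategy argument; with process_group left unset, FSDP builds the intra-node shard group and the inter-node replication group itself by default. A small sketch (the use_zero3_within_node flag is hypothetical, only there to show the two options side by side):

from torch.distributed.fsdp import ShardingStrategy

# hypothetical flag, just for this sketch
use_zero3_within_node = False

strategy = (
    ShardingStrategy.HYBRID_SHARD              # ZeRO-3-style sharding inside each node
    if use_zero3_within_node
    else ShardingStrategy._HYBRID_SHARD_ZERO2  # ZeRO-2-style sharding inside each node
)
# pass `strategy` as sharding_strategy to the same FSDP(...) call as in setup_model above;
# both options replicate the model across nodes, so all_gathers stay node-local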

Also, it's not advisable to just increase the batch size until you run out of memory and then lower it a bit, because CUDA malloc retries go up and create bottlenecks. It's important to find the maximum batch size that keeps CUDA malloc retries at 1 or 2 at most (you can check this with torch.cuda.memory_summary()), even if that batch size is lower. Throughput is highest there.
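
One way to watch that counter programmatically (a sketch; the logging cadence is arbitrary) is to read num_alloc_retries from torch.cuda.memory_stats() instead of eyeballing memory_summary():

import torch

# "num_alloc_retries" counts how often the caching allocator had to free
# cached blocks and retry; if it keeps climbing past ~1-2, the batch size
# is too aggressive for the available memory.
def alloc_retries(device=None):
    return torch.cuda.memory_stats(device).get("num_alloc_retries", 0)

# example: log it once per optimizer step
print(f"cuda malloc retries so far: {alloc_retries()}")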

These are the results of the experiments I did to test:

LEN 4600

FULL_SHARD and PRE and GRAD_CKPT:
SHARD_OP and PRE and GRAD_CKPT:
SHARD_GRAD_OP and PRE:
FULL_SHARD and PRE
FULL_SHARD and POST -- 3 NODES
HYBRID_ZERO_2 and POST and GRAD_CKPT -- 3 NODES
HYBRID_SHARD and POST and GRAD_CKPT -- 3 NODES
SHARD_GRAD_OP and POST and GRAD_CKPT -- 3 NODES
FULL_SHARD and POST and GRAD_CKPT -- 3 NODES

CPU_OFFLOAD and POST and GRAD_CKPT -- 3 NODES

len 2048

FULL_SHARD, POST, GRAD_CKPT -- 3 NODES

len 2600

SHARD_GRAD_OP, POST, GRAD_CKPT -- 5 NODES
FULL_SHARD, POST, GRAD_CKPT -- 5 NODES
HYBRID (ZERO 3), PRE, GRAD_CKPT -- 5 NODES
HYBRID (ZERO 3), POST, GRAD_CKPT -- 5 NODES

len 2600 cache=True

SHARD_GRAD_OP, POST
SHARD_GRAD_OP, PRE
FULL_SHARD, PRE
ani300 commented 2 months ago

Hi, I've been working with @sahilsuneja1 on getting more throughput out of Mixtral for speculator training. I believe it'll also be useful for InstructLab testing, given the code you posted. Feel free to open a chat with both of us if this is still relevant.