aldopareja opened 3 months ago
Performed many experiments in the prod cluster trying to get the best throughput using PyTorch's FSDP. In general, when the number of nodes increases above a certain threshold, the best throughput comes from the HYBRID_SHARD approach, which does FULL_SHARD within a single node while each node keeps a full copy of the model, so most of the all_gather
operations happen inside a node, over much faster networking than inter-node communication.
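For reference, a rough mapping of the FSDP sharding strategies discussed here onto ZeRO stages (names from `torch.distributed.fsdp`; note that `_HYBRID_SHARD_ZERO2` is a private enum member in current PyTorch):

```python
from torch.distributed.fsdp import ShardingStrategy

# Rough mapping of FSDP sharding strategies to ZeRO stages:
#   FULL_SHARD          ~ ZeRO-3: params, grads, optimizer state sharded across all ranks
#   SHARD_GRAD_OP       ~ ZeRO-2: grads + optimizer state sharded, full params kept
#   NO_SHARD            ~ plain DDP: everything replicated
#   HYBRID_SHARD        ~ FULL_SHARD within each node, model replicated across nodes
#   _HYBRID_SHARD_ZERO2 ~ SHARD_GRAD_OP within each node, replicated across nodes
```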
Also, it's not advisable to just increase the batch size until you run out of memory and then lower it slightly, because CUDA malloc retries go up and create bottlenecks. It's important to find the maximum batch size that keeps the allocator's retry count at 1 or 2 at most (you can see it in torch.cuda.memory_summary()), even if that batch size is lower. Throughput is highest there.
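A small sketch of how you might automate that check. The retry count is also available programmatically as the `num_alloc_retries` key of `torch.cuda.memory_stats()`; the helper name and the retry threshold here are illustrative, not from the original post:

```python
MAX_RETRIES = 2  # heuristic from the experiments above: at most 1-2 retries

def batch_size_ok(stats: dict, max_retries: int = MAX_RETRIES) -> bool:
    """Return True if the CUDA caching allocator had to retry at most
    `max_retries` times. `stats` is the dict returned by
    torch.cuda.memory_stats() after a few warm-up training steps."""
    return stats.get("num_alloc_retries", 0) <= max_retries

# Usage after a few training steps:
#   import torch
#   if not batch_size_ok(torch.cuda.memory_stats()):
#       # lower the per-GPU batch size and retry
#       ...
```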
These are the results of the experiments I did to test:
Hi, I've been working with @sahilsuneja1 on getting more throughput out of Mixtral for speculator training. I believe it'll also be useful for InstructLab testing, given the code you posted. Feel free to open a chat with both of us if this is still relevant.
The InstructLab backend currently focuses on Mistral fine-tuning, and I'm trying to maximize throughput for that. If anyone notices anything obvious or has any suggestions, I'd truly appreciate it. @raghukiran1224 mentioned that posting an issue here would potentially help.
I'm currently seeing a throughput of around 90 samples per second at a max context length of 2600 tokens (though the average is only around 500 tokens) on 80 GPUs in prod Vela. On a single node I get around 11.2 samples per second, and the best configuration is SHARD_GRAD_OP (ZeRO stage 2) with no gradient checkpointing.
The main bottleneck is the networking, so having the largest possible batch size maximizes throughput, since the communication cost stays almost constant regardless of batch size. For that reason I ended up using HYBRID_SHARD_ZERO2 and enabling checkpointing to reach a batch size of 20 samples per GPU at 2600 max length.
These are the main parts to look at:
Model setup
Currently using HYBRID_SHARD_ZERO2, but I have experimented with all the possibilities. I couldn't get torch.compile to work, and I had to enable gradient checkpointing to maximize the batch size.
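A minimal sketch of what that setup looks like, assuming an HF-style model (the function name, the `use_orig_params` flag, and the hasattr guard are my assumptions, not the original code):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

def setup_model(model: torch.nn.Module) -> FSDP:
    """Hypothetical sketch of the setup described above: enable activation
    checkpointing first (assuming an HF-style model), then wrap with the
    hybrid ZeRO-2 strategy so all_gathers stay within a node."""
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()  # needed to fit bs=20/GPU at 2600 ctx
    return FSDP(
        model,
        sharding_strategy=ShardingStrategy._HYBRID_SHARD_ZERO2,
        device_id=torch.cuda.current_device(),
        use_orig_params=True,  # assumption: often needed for HF / torch.compile
    )
```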
training loop
Importantly, use_cache=False: even though that line is commented out, the KV cache effectively ends up disabled anyway, because transformers forces use_cache=False whenever gradient checkpointing is enabled (the two are incompatible).
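To avoid relying on that implicit behavior (and the warning it emits), the flag can be set explicitly. A small sketch, with a hypothetical helper name:

```python
def prepare_for_training(model):
    """Disable the KV cache and turn on activation checkpointing.
    HF transformers forces use_cache=False when checkpointing anyway,
    but setting it explicitly documents the intent and silences the warning."""
    model.config.use_cache = False          # KV cache is useless during training
    model.gradient_checkpointing_enable()   # trade compute for activation memory
    return model
```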