foundation-model-stack / avengers

Apache License 2.0
Telescoping cache precision/throughput issues #1

Open daviswer opened 1 month ago

daviswer commented 1 month ago

We have implemented Triton kernels for matmul operations involving a telescoping cache in the telescoping-kernel branch. The kernels pass their respective correctness checks (also included), but deploying them to our training pipeline is not straightforward because Triton does not support atomic-add in bf16 (see here).

We instead cast to fp16 before this op, but loss curves on a test llama3-1.8B model diverge when we do this:

[Figure: loss_curve]

Loss curves do not diverge when running the kernels in fp32. Unfortunately this sacrifices our speed gains. We're currently evaluating fp32 atomic-adds only, and will update here.

Running these matmuls in fp16 also breaks the vanilla pytorch code, so this is almost certainly a precision issue. If internal fp32 casting does not fix the diverging loss, can the kernel code be massaged to avoid these issues?
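For anyone who wants to see the general effect in isolation, here is a minimal sketch of how a low-precision accumulator can lose updates that an fp32 accumulator keeps. It is not the kernel code and does not claim to reproduce the divergence above: it emulates fp16 and fp32 rounding with the standard-library `struct` module, and the helper names (`round_to`, `accumulate`) and the toy input are made up for illustration.

```python
import struct


def round_to(fmt, x):
    """Round a Python float to the precision of a struct format ('e' = fp16, 'f' = fp32)."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def accumulate(values, fmt):
    """Sum values one at a time, rounding the running total after every add.

    This mimics an accumulator stored in the given precision (hypothetical
    stand-in for repeated atomic-adds into a low-precision buffer).
    """
    total = 0.0
    for v in values:
        total = round_to(fmt, total + round_to(fmt, v))
    return total


# Toy input: many small contributions whose exact sum is 100.
values = [0.01] * 10_000

fp16_sum = accumulate(values, "e")  # accumulator kept in fp16
fp32_sum = accumulate(values, "f")  # accumulator kept in fp32

print(f"fp16 accumulator: {fp16_sum}")
print(f"fp32 accumulator: {fp32_sum}")
```

With the fp16 accumulator the running total stops growing once the spacing between representable values exceeds the size of each addend, so it ends far below 100. The fp32 accumulator stays close to the exact sum.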

daviswer commented 1 month ago

Update: performing atomic-adds in fp32 produces the desired behavior, with minimal extra speed/memory overhead. Now it's a question of optimizing throughput: we're currently getting ~3850 tokens/sec/gpu for this particular training setup, compared to ~2550 for the pure PyTorch baseline and ~10600 for Flash Attention. That is roughly 1.5x the baseline but only about 36% of Flash Attention throughput.

[Figure: loss_curve]

daviswer commented 1 month ago

Update 2: it turns out that the way we implemented the forward pass around the above kernels also made it amenable to standard attention with a custom mask (visualized below for seq len 512). So we're now running telescoping cache training - stably and relatively quickly - at 8B scale, using memory-efficient attention through PyTorch SDPA (as SDPA-flash attention still doesn't support custom masks, apparently). Further speedups will be possible if we can enable Flash Attention with custom masks in this context.
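As a reference for what "standard attention with a custom mask" computes, here is a small pure-Python sketch. The mask is a plain causal mask used as a stand-in, not the telescoping pattern, and `masked_attention` is a hypothetical helper written for this comment. In PyTorch the equivalent boolean mask would be passed through the `attn_mask` argument of `torch.nn.functional.scaled_dot_product_attention`.

```python
import math


def masked_attention(q, k, v, mask):
    """Scaled dot-product attention with an arbitrary boolean mask.

    q, k, v: lists of vectors, each of shape (seq_len x dim).
    mask[i][j] is True when query i may attend to key j.
    Assumes every row of the mask has at least one True entry.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        # Masked-out positions get -inf so they receive zero softmax weight.
        scores = [
            scale * sum(a * b for a, b in zip(qi, kj)) if mask[i][j] else -math.inf
            for j, kj in enumerate(k)
        ]
        top = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - top) for s in scores]
        norm = sum(weights)
        out.append(
            [sum(w * vj[d] for w, vj in zip(weights, v)) / norm for d in range(len(v[0]))]
        )
    return out


# Toy inputs (seq_len = 4, dim = 2); values are arbitrary.
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
v = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]

# Stand-in mask: causal (query i sees keys 0..i). Any boolean pattern works here.
n = len(q)
mask = [[j <= i for j in range(n)] for i in range(n)]

out = masked_attention(q, k, v, mask)
print(out[0])  # the first query can only see the first key
```

Swapping in a different boolean pattern for `mask` changes which keys each query sees without touching the rest of the computation, which is why a mask-based formulation can reuse an existing attention implementation.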

[Figure: custom attention mask, seq len 512]