mpatel31415 opened 1 day ago
Reposting from Slack: here are the differences between their configs (the first dict is for Mistral, the second for Llama):
{'block_size': 4096,
'padded_vocab_size': 32000,
'rope_base': 10000,
'sliding_window_layer_placing': 1,
'sliding_window_size': 4096,
'vocab_size': 32000}
{'block_size': 8192,
'padded_vocab_size': 128256,
'rope_base': 500000,
'sliding_window_layer_placing': None,
'sliding_window_size': None,
'vocab_size': 128000}
The main difference here is that the Mistral config enables a sliding window attention mask (sliding_window_size and sliding_window_layer_placing are set), so the "if self.apply_sliding_window_attention" code path is taken. When attn_mask is not None, cuDNN probably picks a different kernel, which could be less efficient than a kernel available in PyTorch.
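To make the code-path difference concrete, here is a minimal pure-Python sketch. The helper names (uses_sliding_window, sliding_window_causal_mask, materialized_mask_bytes) are hypothetical, the layer-selection rule (a window size is set and the block index is a multiple of sliding_window_layer_placing) is an assumption about how the config is interpreted rather than a quote of the model code, and n_layer = 32 is assumed for both models. It shows why Mistral takes the sliding-window path in every layer while Llama never does, and what the dense mask passed to SDPA as attn_mask looks like.

```python
# Hypothetical sketch: which layers take the sliding-window path, and what the
# dense mask built on that path looks like. Not the actual model code.

def uses_sliding_window(block_idx, sliding_window_size, sliding_window_layer_placing):
    # Assumption: a layer uses sliding-window attention when a window size is
    # configured and its index is a multiple of the layer-placing stride.
    if sliding_window_size is None or sliding_window_layer_placing is None:
        return False
    return block_idx % sliding_window_layer_placing == 0


def sliding_window_causal_mask(seq_len, window):
    # mask[i][j] is True when query i may attend to key j:
    # causal (j <= i) and inside the window (i - j < window).
    return [[j <= i and i - j < window for j in range(seq_len)] for i in range(seq_len)]


def materialized_mask_bytes(seq_len, bytes_per_element=2):
    # Size of a dense (seq_len x seq_len) mask, e.g. 2 bytes/element for bf16.
    return seq_len * seq_len * bytes_per_element


# Values taken from the config diff above; n_layer = 32 is an assumption.
mistral = {"sliding_window_size": 4096, "sliding_window_layer_placing": 1}
llama = {"sliding_window_size": None, "sliding_window_layer_placing": None}
n_layer = 32

mistral_layers = [i for i in range(n_layer) if uses_sliding_window(i, **mistral)]
llama_layers = [i for i in range(n_layer) if uses_sliding_window(i, **llama)]
print(len(mistral_layers), len(llama_layers))  # 32 0
print(materialized_mask_bytes(4096) / 2**20)  # 32.0 (MiB)
for row in sliding_window_causal_mask(4, 2):
    print(row)
```

With sliding_window_layer_placing = 1, every block index passes the modulo check, so under this assumption every Mistral layer builds a dense mask. At block_size = 4096 such a mask has 4096 x 4096 = 16,777,216 entries, or 32 MiB in a 2-byte dtype such as bf16.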
@vedaanta's response:
On the above topic: cuDNN does have performant kernels for sliding_window_attention on Hopper. In the code path linked by Ivan, the mask is being materialized, and cuDNN kernels today are much slower when they have to read the mask from global memory (gmem) before the softmax. But if Thunder can somehow pass the sliding window size to cudnnex, cuDNN can generate the mask on the fly and provide a good speedup (maybe via a sliding_window_size option in ltorch.sdpa, just like the is_causal option).
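The idea of passing a window size instead of a mask can be illustrated with a naive reference implementation. This is a hypothetical, single-head, pure-Python sketch: sdpa_reference and its sliding_window_size argument are illustrative only and are not an existing ltorch.sdpa or cudnnex API. The point is that the allowed key positions for each query are computed from (i, j, window) inline, so there is no mask tensor to read from memory. As an assumption, the window is applied on top of causal masking, as in Mistral, and q, k, v are self-attention inputs of the same length.

```python
import math


def sdpa_reference(q, k, v, is_causal=False, sliding_window_size=None):
    """Naive single-head scaled dot-product attention on lists of vectors.

    Hypothetical API: `sliding_window_size` mirrors the proposed option next to
    `is_causal`. The mask is never materialized; each (i, j) pair is checked inline.
    Assumes self-attention, i.e. len(q) == len(k) == len(v).
    """
    if sliding_window_size is not None and sliding_window_size < 1:
        raise ValueError("sliding_window_size must be >= 1")
    # Assumption: a sliding window implies causal masking (as in Mistral).
    causal = is_causal or sliding_window_size is not None
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        # Keep only the key positions this query is allowed to see.
        scores = []
        for j, kj in enumerate(k):
            if causal and j > i:
                continue  # future token
            if sliding_window_size is not None and i - j >= sliding_window_size:
                continue  # outside the window
            scores.append((j, scale * sum(a * b for a, b in zip(qi, kj))))
        # Numerically stable softmax over the surviving scores.
        m = max(s for _, s in scores)
        weights = [(j, math.exp(s - m)) for j, s in scores]
        total = sum(w for _, w in weights)
        row = [0.0] * len(v[0])
        for j, w in weights:
            for c, x in enumerate(v[j]):
                row[c] += (w / total) * x
        out.append(row)
    return out


x = [[0.1 * t, 1.0 - 0.2 * t] for t in range(6)]
full = sdpa_reference(x, x, x, is_causal=True)
windowed = sdpa_reference(x, x, x, sliding_window_size=6)
print(full == windowed)  # True: a window >= seq_len removes nothing beyond causal masking
```

Since the Mistral config has block_size == sliding_window_size == 4096, the same reasoning implies that for sequences of at most 4096 tokens the window never removes a position that causal masking keeps.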
Let's measure what speedup can be achieved by not materializing the mask and using the special attributes for implicit masking instead. We have a parametrized SDPA benchmark at https://github.com/Lightning-AI/lightning-thunder/blob/79e59d0c5c5f8aa8ef80eb31f3fe918466d64c1c/thunder/benchmarks/targets.py#L386, and its function definition lives at https://github.com/Lightning-AI/lightning-thunder/blob/79e59d0c5c5f8aa8ef80eb31f3fe918466d64c1c/thunder/benchmarks/__init__.py#L2634-L2636. It's straightforward to change "torch.nn.functional" to "thunder.torch" there to benchmark only the Thunder implementation. Then the sliding window length parameter needs to be passed to the cuDNN executor (here's an example usage https://github.com/NVIDIA/cudnn-frontend/blob/de355c7094af70467f2b264f531ab5c5f4401c42/test/python_fe/test_mhas.py#L677).
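The parametrized benchmark linked above remains the place for real measurements; the snippet below is only a minimal stand-alone timing helper, assuming a quick comparison outside that harness is wanted. time_callable and speedup are hypothetical names. The optional sync callback (for example torch.cuda.synchronize) matters for GPU work because CUDA kernels are launched asynchronously, so the clock should be read only after the device has finished.

```python
import statistics
import time


def time_callable(fn, *, warmup=3, iters=10, sync=None):
    """Return the median wall-clock time of fn() in seconds.

    `sync` is an optional callback (e.g. torch.cuda.synchronize) that blocks
    until queued device work has finished, so each sample covers the full run.
    """
    for _ in range(warmup):
        fn()  # warm-up runs absorb compilation and caching effects
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def speedup(baseline_s, candidate_s):
    # > 1.0 means the candidate is faster than the baseline.
    return baseline_s / candidate_s


# Usage sketch with plain Python work standing in for the SDPA calls.
baseline = time_callable(lambda: sum(i * i for i in range(20000)))
candidate = time_callable(lambda: sum(range(20000)))
print(f"speedup: {speedup(baseline, candidate):.2f}x")
```

For the experiment proposed here the interesting ratio would be something like speedup(materialized_mask_s, implicit_mask_s), and for the feature request below it would be speedup(eager_s, thunder_s) computed per model; these variable names are placeholders.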
🚀 Feature
Make Thunder + Mistral-7B-v0.1 as fast as Thunder + Llama3-8b relative to Eager mode, i.e. reach a similar Thunder-over-Eager speedup for both models.
Motivation
Below are data for:
The main difference is that Mistral-7B-v0.1 fits on a single GPU but Llama3-8b does not, so for Llama3-8b we use distributed training.