A higher loss (9.5602 vs. 9.3164) was observed for the DTensor case after 10 steps on the llama2 debug model. This happens even without applying rotary embedding (and hence without the complex number multiplication issue mentioned in #267).
Note: to use math attention with DTensor, one needs to set `_allow_implicit_replication` to `True`, because a non-DTensor mask will be generated when `is_causal=True` for SDPA.
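For reference, a minimal sketch of that workaround, assuming the `implicit_replication` context manager from `torch.distributed._tensor.experimental` (which toggles `_allow_implicit_replication`) and the `sdpa_kernel` API from `torch.nn.attention`; exact module paths may differ across PyTorch versions:

```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.distributed._tensor.experimental import implicit_replication

def math_sdpa(q, k, v):
    # implicit_replication() flips _allow_implicit_replication so the plain
    # (non-DTensor) causal mask generated for is_causal=True is treated as
    # replicated rather than raising a mixed DTensor/Tensor error.
    with implicit_replication(), sdpa_kernel(SDPBackend.MATH):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```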
This issue doesn't seem to be urgent, as math attention is only a fallback option for flash attention and memory-efficient attention.