Instead of adding `precision` to each `dot_general` or `einsum()` call, we could use `jax.default_matmul_precision`, which offers a similar set of three options. However, I ran into issues during the flash attention kernel call (here), so I fell back to adding `precision` to each matmul.
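For reference, a minimal sketch of the two approaches (shapes and einsum spec are illustrative, not the actual diff):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((8, 16), dtype=jnp.bfloat16)
w = jnp.ones((16, 32), dtype=jnp.bfloat16)

# Option 1: set the default globally via context manager.
# This is what conflicted with the flash attention kernel call.
with jax.default_matmul_precision("float32"):
    y = x @ w

# Option 2: pass precision per call (the approach taken here).
y = jnp.einsum("bd,df->bf", x, w, precision=jax.lax.Precision.HIGHEST)
y = jax.lax.dot_general(
    x, w,
    dimension_numbers=(((1,), (0,)), ((), ())),  # contract x dim 1 with w dim 0
    precision=jax.lax.Precision.HIGHEST,
)
```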
Test
base Mixtral 8x7b run (without changes, dtype=bfloat16, weight=f32): link - 2637.763 tokens/s/chip
base Mixtral 8x7b run (with changes, dtype=bfloat16, weight=f32): link - 2625.901 tokens/s/chip (0.5% regression)
matmul_precision=float32 run (with changes, dtype=bfloat16, weight=f32): link - 2588.229 tokens/s/chip (1.88% regression)
matmul_precision=float32 run (with changes, dtype=f32, weight=f32): OOM 14GB link - expected, since f32 activations roughly double memory use compared to bfloat16