Closed: vanbasten23 closed this 1 week ago
Can someone show me why it is important to expose all the SplashAttention parameters independently?
We talked with Sharad early on and took his recommendation to tune the flags in three blocks, so my current understanding is that this is all that is needed to tune SplashAttention.
cc @gobbleturk @khatwanimohit
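For context, here is a minimal sketch of what I understand the "three blocks" to mean, assuming they correspond to the forward, dKV-backward, and dQ-backward block-size groups in the JAX splash-attention kernel. The module path, field names, and example values are taken from the public `jax.experimental.pallas.ops.tpu.splash_attention` source and are illustrative, not the exact MaxText flag names.

```python
# Illustrative sketch, not the exact MaxText flags: the splash-attention block-size
# knobs in JAX, grouped into the three blocks mentioned above. Verify field names
# against the JAX version pinned by MaxText.
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

block_sizes = splash_attention_kernel.BlockSizes(
    # Block 1: forward pass
    block_q=512,
    block_kv=1024,
    block_kv_compute=512,
    # Block 2: backward pass w.r.t. K and V
    block_q_dkv=512,
    block_kv_dkv=1024,
    block_kv_dkv_compute=512,
    # Block 3: backward pass w.r.t. Q
    block_q_dq=512,
    block_kv_dq=1024,
)
```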
Thanks for the info. This PR is the result of an initiative to tune the SplashAttention Pallas kernel together with model parameters using MaxText and RayTune on Trillium, to demonstrate Trillium's capabilities before its GA. One of the steps is to expose the full set of SplashAttention tunable parameters in MaxText.
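To make that workflow concrete, below is an illustrative Ray Tune sketch of the kind of search involved. The `sa_block_*` names, the trial body, and the placeholder metric are assumptions for illustration only, not code from this PR; a real trial would launch MaxText/train.py with the sampled values and report the measured step time.

```python
# Illustrative only (Ray 2.x API): a Ray Tune search space over splash-attention
# block sizes. Parameter names and the placeholder metric are hypothetical.
from ray import train, tune

def run_maxtext_trial(config):
    # Placeholder so the sketch is self-contained; in practice, launch MaxText
    # with the sampled block sizes and parse the per-step time from its logs.
    step_time_ms = 1000.0 / (config["sa_block_q"] + config["sa_block_kv"])
    train.report({"step_time_ms": step_time_ms})

search_space = {
    "sa_block_q": tune.choice([256, 512, 1024]),
    "sa_block_kv": tune.choice([512, 1024, 2048]),
    "sa_block_kv_compute": tune.choice([256, 512]),
}

tuner = tune.Tuner(
    run_maxtext_trial,
    param_space=search_space,
    tune_config=tune.TuneConfig(metric="step_time_ms", mode="min", num_samples=8),
)
results = tuner.fit()
print(results.get_best_result().config)
```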
Let me check with Sharad to see if it's worthwhile to expose all tunable parameters.
Got the response; commenting it here for future reference. It seems we can still tune some other parameters, such as q/k/v_layout, so it may be worth doing.
Full tunable parameters for SplashAttention are here
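As a hedged illustration of the extra knobs mentioned above, here is roughly how the layout parameters appear in the public JAX kernel API. Argument names come from `jax.experimental.pallas.ops.tpu.splash_attention` and may differ across JAX versions, so this is a sketch rather than the actual MaxText wiring.

```python
# Sketch only: passing the q/k/v layout knobs alongside block sizes when building
# the splash-attention kernel. Names are from the public JAX module and may differ
# across JAX versions.
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)

seq_len, num_heads = 4096, 8
multi_head_mask = splash_attention_mask.MultiHeadMask(
    masks=[splash_attention_mask.CausalMask(shape=(seq_len, seq_len))] * num_heads
)

splash_kernel = splash_attention_kernel.make_splash_mha(
    mask=multi_head_mask,
    head_shards=1,
    q_seq_shards=1,
    block_sizes=splash_attention_kernel.BlockSizes.get_default(),
    # Operand layouts (HEAD_DIM_MINOR vs. SEQ_MINOR) change how the kernel reads
    # Q/K/V and are another tunable dimension.
    q_layout=splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
    k_layout=splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
    v_layout=splash_attention_kernel.QKVLayout.HEAD_DIM_MINOR,
)
```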
Test plan:
$ python3 MaxText/train.py MaxText/configs/base.yml run_name=$YOUR_JOB_NAME base_output_directory=gs://xiowei-bucket-ctmd dataset_type=synthetic steps=10
cc: @miladm