AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Expose all SplashAttention tunable parameters in the workload. #1016

Closed: vanbasten23 closed this pull request 1 week ago.

vanbasten23 commented 2 weeks ago

The full set of tunable parameters for SplashAttention is listed here.

Test plan:

    $ python3 MaxText/train.py MaxText/configs/base.yml run_name=$YOUR_JOB_NAME base_output_directory=gs://xiowei-bucket-ctmd dataset_type=synthetic steps=10
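A minimal sketch of what "exposing all tunable parameters" could look like, assuming each kernel knob becomes its own flat config key. The `sa_*` key names and the `SplashBlockSizes` class are hypothetical stand-ins for illustration, not taken from this PR's diff. The field names mirror `BlockSizes` in `jax.experimental.pallas.ops.tpu.splash_attention`, but the class is defined locally so the sketch runs without JAX, and layouts are plain strings rather than `QKVLayout` members.

```python
from dataclasses import dataclass, fields
from typing import Optional


# Hypothetical local stand-in whose field names mirror SplashAttention's
# BlockSizes; layouts are plain strings here instead of QKVLayout members.
@dataclass(frozen=True)
class SplashBlockSizes:
    block_q: int
    block_kv: int
    block_kv_compute: Optional[int] = None
    block_q_dkv: Optional[int] = None
    block_kv_dkv: Optional[int] = None
    block_kv_dkv_compute: Optional[int] = None
    block_q_dq: Optional[int] = None
    block_kv_dq: Optional[int] = None
    use_fused_bwd_kernel: bool = False
    q_layout: str = "HEAD_DIM_MINOR"
    k_layout: str = "HEAD_DIM_MINOR"
    v_layout: str = "HEAD_DIM_MINOR"


def block_sizes_from_config(config: dict, q_len: int, kv_len: int) -> SplashBlockSizes:
    """Read one hypothetical `sa_<field>` key per field.

    Tile sizes are clamped to the matching sequence length so that short
    sequences never request a tile larger than the input.
    """
    q_tiles = {"block_q", "block_q_dkv", "block_q_dq"}
    kwargs = {}
    for f in fields(SplashBlockSizes):
        value = config["sa_" + f.name]
        if f.name.startswith("block_"):
            value = min(value, q_len if f.name in q_tiles else kv_len)
        kwargs[f.name] = value
    return SplashBlockSizes(**kwargs)


# Every knob set independently (values are illustrative only).
config = {
    "sa_block_q": 512, "sa_block_kv": 512, "sa_block_kv_compute": 512,
    "sa_block_q_dkv": 512, "sa_block_kv_dkv": 512, "sa_block_kv_dkv_compute": 512,
    "sa_block_q_dq": 512, "sa_block_kv_dq": 512,
    "sa_use_fused_bwd_kernel": False,
    "sa_q_layout": "HEAD_DIM_MINOR", "sa_k_layout": "HEAD_DIM_MINOR",
    "sa_v_layout": "HEAD_DIM_MINOR",
}
bs = block_sizes_from_config(config, q_len=256, kv_len=2048)
print(bs.block_q, bs.block_kv)  # 256 512
```

Deriving the keys from the dataclass fields keeps the config and the kernel parameters in lockstep: adding a field means adding exactly one `sa_` key.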

cc: @miladm

vanbasten23 commented 2 weeks ago

> Can someone show me why it is important to expose all the SplashAttention parameters independently?
>
> We talked with Sharad early on and took his recommendation of tuning the flags in three blocks. So my current understanding is that this is all that is needed to tune splash attention.
>
> cc @gobbleturk @khatwanimohit
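To make the "three blocks" alternative concrete, here is a sketch under one assumed reading: one `(block_q, block_kv)` pair per kernel pass (forward, dK/dV backward, dQ backward), expanded into the eight per-field tile sizes of SplashAttention's `BlockSizes`. The grouping and the helper `expand_three_blocks` are hypothetical; the thread does not spell out which flags form each block.

```python
def expand_three_blocks(fwd, dkv, dq):
    """Expand three (block_q, block_kv) pairs into eight tile sizes.

    Assumed grouping: forward, dK/dV backward, and dQ backward each get one
    pair, and each *_compute size is tied to its pass's kv tile. That tie is
    the coupling which exposing every parameter independently would remove.
    """
    (fq, fkv), (dkq, dkkv), (dqq, dqkv) = fwd, dkv, dq
    return {
        "block_q": fq,
        "block_kv": fkv,
        "block_kv_compute": fkv,
        "block_q_dkv": dkq,
        "block_kv_dkv": dkkv,
        "block_kv_dkv_compute": dkkv,
        "block_q_dq": dqq,
        "block_kv_dq": dqkv,
    }


sizes = expand_three_blocks(fwd=(512, 1024), dkv=(512, 512), dq=(256, 512))
print(len(sizes), sizes["block_kv_compute"])  # 8 1024
```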

Thanks for the info. This PR is the result of an initiative to tune the SplashAttention Pallas kernel for given model parameters using MaxText and Ray Tune on Trillium, in order to demonstrate Trillium's capabilities before its general availability (GA). One of the steps is to expose the full set of SplashAttention tunable parameters in MaxText.

Let me check with Sharad to see if it's worthwhile to expose all tunable parameters.
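For a sense of what a tuner would search over once every knob is exposed, below is a sketch of a search space with plain random sampling standing in for a Ray Tune sampler. The `sa_*` key names and the candidate values are hypothetical; each sampled trial is rendered as `key=value` overrides, the form `train.py` accepts in the test plan above.

```python
import random

# Hypothetical knobs and illustrative candidate values (not recommendations).
BLOCK_CHOICES = [256, 512, 1024, 2048]
BLOCK_KEYS = [
    "sa_block_q", "sa_block_kv", "sa_block_kv_compute",
    "sa_block_q_dkv", "sa_block_kv_dkv", "sa_block_kv_dkv_compute",
    "sa_block_q_dq", "sa_block_kv_dq",
]
LAYOUT_CHOICES = ["HEAD_DIM_MINOR", "SEQ_MINOR"]
LAYOUT_KEYS = ["sa_q_layout", "sa_k_layout", "sa_v_layout"]


def search_space_size():
    """Count all combinations: 8 tile sizes, 3 layouts, 1 fused-bwd flag."""
    return (len(BLOCK_CHOICES) ** len(BLOCK_KEYS)
            * len(LAYOUT_CHOICES) ** len(LAYOUT_KEYS)
            * 2)


def sample_trial(rng):
    """Draw one configuration uniformly at random (plain random search)."""
    trial = {k: rng.choice(BLOCK_CHOICES) for k in BLOCK_KEYS}
    trial.update({k: rng.choice(LAYOUT_CHOICES) for k in LAYOUT_KEYS})
    trial["sa_use_fused_bwd_kernel"] = rng.choice([True, False])
    return trial


def to_overrides(trial):
    """Render a trial as space-separated key=value command-line overrides."""
    return " ".join(f"{k}={v}" for k, v in trial.items())


rng = random.Random(0)
trial = sample_trial(rng)
print(search_space_size())  # 1048576
print(to_overrides(trial))
```

With four candidates per tile, the fully exposed space has 4^8 × 2^3 × 2 = 1,048,576 points. Under the assumed three-pair grouping it shrinks to 4^6 = 4,096, which is the trade-off being debated in this thread.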

vanbasten23 commented 1 week ago


I got the response and am recording it here for future reference: it seems we can still tune some other parameters, such as q/k/v_layout, so exposing them may be worthwhile.
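Why q/k/v_layout is a separate tuning axis: as I understand the JAX SplashAttention kernel, `QKVLayout` has two members, `HEAD_DIM_MINOR` (physical `[..., seq_len, head_dim]`) and `SEQ_MINOR` (physical `[..., head_dim, seq_len]`). The layout changes only the physical layout the kernel enforces, not the logical input. The sketch below defines a local enum mirroring those two members (so it runs without JAX) and enumerates the combinations a tuner could try; `physical_shape` is a hypothetical helper for illustration.

```python
import enum
import itertools


# Local stand-in mirroring the two QKVLayout members (runs without JAX).
class QKVLayout(enum.IntEnum):
    HEAD_DIM_MINOR = enum.auto()  # physical [..., seq_len, head_dim]
    SEQ_MINOR = enum.auto()       # physical [..., head_dim, seq_len]


def physical_shape(seq_len, head_dim, layout):
    """Last two axes of the physical layout for a logical [seq_len, head_dim] input."""
    if layout is QKVLayout.HEAD_DIM_MINOR:
        return (seq_len, head_dim)
    return (head_dim, seq_len)


# Every (q_layout, k_layout, v_layout) combination.
combos = list(itertools.product(QKVLayout, repeat=3))
print(len(combos))  # 8
print(physical_shape(2048, 128, QKVLayout.SEQ_MINOR))  # (128, 2048)
```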