Closed by patemotter 2 months ago
Adds the Ragged Attention Pallas kernel as an option when performing autoregressive attention. By default this is disabled and does not interfere with the existing AR attention; it can be enabled with the CLI argument `use_ragged_attention=true`.

Improvements are most noticeable when `quantize_kvcache=false`, `ar_cache_axis_order = prefill_cache_axis_order = "0,2,1,3"`, and as `max_prefill_predict_length` and `max_target_length` increase.
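A minimal sketch of an invocation that enables the kernel under the conditions listed above. The flag names are taken from this description; the entry point (`MaxText/decode.py`), the base config path, and the specific length values are assumptions about a typical MaxText setup, not part of this PR.

```bash
# Hypothetical MaxText decode invocation; adjust paths and length values
# for your checkout. Flag names match this PR's description.
python3 MaxText/decode.py MaxText/configs/base.yml \
  use_ragged_attention=true \
  quantize_kvcache=false \
  ar_cache_axis_order="0,2,1,3" \
  prefill_cache_axis_order="0,2,1,3" \
  max_prefill_predict_length=1024 \
  max_target_length=2048
```

Larger `max_prefill_predict_length` and `max_target_length` values should make the ragged kernel's advantage over the default AR attention more visible, presumably because a ragged kernel can skip attention work beyond each sequence's true length rather than computing over the full padded cache.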