AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

Remove aqt einsum for dropping #1036

Open mailvijayasingh opened 1 week ago

mailvijayasingh commented 1 week ago

i) Remove AQTEinsum from the 1st step, where inputs are masked with the dispatch mask. Keeping it causes a failure when loading pre-quantized weights unless capacity_factor > 0, which should not be required: since the capacity factor only affects the forward pass, we should be able to load the weights regardless.

ii) Remove AQTEinsum from the last step, where the C (capacity) axis is contracted. That contraction is effectively an element-wise op, which AQT does not support.

iii) Pass the top_k weights to generate_masks instead of softmax_prob.
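The dispatch/combine change in steps i) and ii) can be sketched as follows. This is a minimal toy example, not MaxText's actual code: shapes, the identity "expert FFN", and the routing pattern are all made up for illustration, and numpy's einsum is used as a stand-in for jnp.einsum, which behaves the same way.

```python
import numpy as np

# Toy shapes (hypothetical): batch B, sequence S, experts E, capacity C, model dim M.
B, S, E, C, M = 2, 8, 4, 4, 16

inputs = np.ones((B, S, M))            # token activations
dispatch_mask = np.zeros((B, S, E, C))
dispatch_mask[:, :, 0, 0] = 1.0        # toy routing: every token -> expert 0, slot 0
combine_weights = 0.5 * dispatch_mask  # per-slot routing weights

# Step i): the dispatch einsum only applies a 0/1 mask to the inputs, so a
# plain einsum is used here rather than an AQT-quantized einsum.
expert_inputs = np.einsum("BSM,BSEC->EBCM", inputs, dispatch_mask)

# ... the per-expert FFN would run here; identity keeps the sketch short ...
expert_outputs = expert_inputs

# Step ii): combining contracts the capacity axis C, which amounts to an
# element-wise scale-and-scatter per token, an operation AQT does not support.
outputs = np.einsum("BSEC,EBCM->BSM", combine_weights, expert_outputs)
print(outputs.shape)  # (2, 8, 16)
```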

Correctness tests: with capacity_factor = 2 (which is almost dropless), the ROUGE scores for 8x22B are close to the dropless baseline (capacity_factor = -1), which makes sense. capacity_factor = 1.5 was also tested.

| capacity_factor | rouge1 | rouge2 | rougeL | rougeLsum | gen_len | gen_num |
|---|---|---|---|---|---|---|
| 2 (almost dropless) | 44.5123 | 24.5477 | 31.3922 | 42.0639 | 1269606 | 1200 |
| -1 (dropless) | 44.1463 | 24.2765 | 31.2481 | 41.6197 | 1229222 | 1200 |
| 1.5 | 43.9813 | 24.1528 | 30.8425 | 41.5053 | 1312280 | 1200 |

vipannalla commented 1 week ago

Removing "pull ready" and requesting review to avoid merging.