i) Remove AQTEinsum from the 1st step, where inputs are masked with the dispatch mask. Leaving it in causes a failure when loading weights that were pre-quantized without capacity_factor > 0, which should not be the case: since the capacity factor affects only the forward pass, we should still be able to load those weights.
ii) Remove AQTEinsum from the last step, where the C (capacity) axis is contracted; that contraction is an element-wise op, and AQT does not support it.
iii) Pass the top_k weights to generate_masks instead of softmax_prob.
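The first two changes above can be sketched roughly as follows. This is a minimal illustration, not MaxText's actual code: the shapes, the names (`dispatch_mask`, `combine_mask`, `expert_inputs`), and the identity stand-in for the expert FFN are all assumptions; the point is only that the two mask einsums use plain `jnp.einsum` rather than an AQT-wrapped einsum.

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes: B=batch, S=sequence, E=experts, C=capacity, M=model dim.
B, S, E, C, M = 2, 4, 8, 2, 16

rng = jax.random.PRNGKey(0)
inputs = jax.random.normal(rng, (B, S, M))

# dispatch_mask is 0/1 routing of each token to an (expert, capacity slot);
# here every token is routed to expert 0, slot 0, purely for illustration.
dispatch_mask = jnp.zeros((B, S, E, C)).at[:, :, 0, 0].set(1.0)

# Step 1: mask inputs with the dispatch mask. Plain jnp.einsum, not AQT:
# the mask is routing, not a weight tensor, so it should not be quantized,
# and keeping AQT here broke loading of pre-quantized checkpoints.
expert_inputs = jnp.einsum("BSM,BSEC->EBCM", inputs, dispatch_mask)

# ... expert FFN would run here; AQT einsum can stay on those weight
# matmuls. An identity stand-in keeps this sketch self-contained.
expert_outputs = expert_inputs

# Last step: combine expert outputs, contracting the C axis. This is
# effectively an element-wise weighting plus sum, which AQT does not
# support, so it also uses plain jnp.einsum. In practice combine_mask
# would be dispatch_mask scaled by the router's top_k weights.
combine_mask = dispatch_mask
output = jnp.einsum("EBCM,BSEC->BSM", expert_outputs, combine_mask)
```

With the all-to-expert-0 mask above, the combine step reduces over E and C and returns an output of shape (B, S, M), matching the input layout.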
Correctness tests: with capacity_factor = 2 (which is almost dropless), the ROUGE scores for 8x22B are
{'rouge1': 44.5123, 'rouge2': 24.5477, 'rougeL': 31.3922, 'rougeLsum': 42.0639, 'gen_len': 1269606, 'gen_num': 1200}
compared to capacity_factor = -1:
{'rouge1': 44.1463, 'rouge2': 24.2765, 'rougeL': 31.2481, 'rougeLsum': 41.6197, 'gen_len': 1229222, 'gen_num': 1200} which makes sense.
Also tested 8x22B with capacity_factor = 1.5:
{'rouge1': 43.9813, 'rouge2': 24.1528, 'rougeL': 30.8425, 'rougeLsum': 41.5053, 'gen_len': 1312280, 'gen_num': 1200}