I think the code is right. All of `query_states`, `key_states`, and `value_states` are quantized before matrix multiplication.
You can print their shapes before quantization for a more direct and clear observation.
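
For reference, here is a minimal sketch of what such per-token quantization typically looks like; the name `quantize_per_token` and the symmetric int8 scheme are my assumptions for illustration, not necessarily this repo's exact implementation:

```python
import torch

def quantize_per_token(x: torch.Tensor):
    # One scale per token, computed over the last ([-1]) dimension
    # (symmetric int8 is assumed here purely for illustration).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(x / scale).clamp(-128, 127)
    return q, scale

# (batch, heads, seq_len, head_dim), as in a typical attention module
query_states = torch.randn(1, 8, 16, 64)
q_int, q_scale = quantize_per_token(query_states)
print(q_int.shape, q_scale.shape)  # ... (1, 8, 16, 64) and (1, 8, 16, 1)
```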
The matmul between `query_states` and `key_states` here applies a transpose to `key_states`, so in this case the per-token quantization of activations along the `[-1]` dimension fits well (a quick check of this is sketched below).
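
A quick sanity check of why the transpose makes per-token scales compatible (reusing the hypothetical `quantize_per_token` above): the contraction runs over the `[-1]` dimension of both operands, so the scales factor out of the sum as an outer product and a pure integer matmul followed by one rescale per output element is exact:

```python
import torch

def quantize_per_token(x):
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    return torch.round(x / scale).clamp(-128, 127), scale

q = torch.randn(16, 64)  # (seq, head_dim)
k = torch.randn(16, 64)
q_int, q_s = quantize_per_token(q)  # q_s: (16, 1)
k_int, k_s = quantize_per_token(k)  # k_s: (16, 1)

# Integer matmul first, then one rescale per output element:
int_then_scale = (q_int @ k_int.T) * (q_s @ k_s.T)
# Dequantize first, then matmul (the reference result):
dequant_matmul = (q_int * q_s) @ (k_int * k_s).T
print(torch.allclose(int_then_scale, dequant_matmul))  # True
```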
But in the matmul between the attention weights and `value_states` here, there is no transpose on `value_states` after the per-token quantization function. This produces a matmul between tensors of shape `(a, b)` and `(c, d)` where the former is quantized along the `a` dimension but the latter is quantized along the `c` dimension, which seems wrong (I mean the latter should be quantized along the `d` dimension so that the quantized matmul can really be accelerated on hardware).
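
To make the concern concrete, here is a hedged sketch (again reusing the hypothetical `quantize_per_token` above): with a per-token scale on `value_states`, the scale varies along the contraction dimension and sits inside the summation, so an integer matmul followed by a rescale cannot reproduce the dequantized result; quantizing `value_states` with one scale per `d` column instead restores the factorization:

```python
import torch

def quantize_per_token(x):
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    return torch.round(x / scale).clamp(-128, 127), scale

attn = torch.softmax(torch.randn(16, 16), dim=-1)  # shape (a, b)
v = torch.randn(16, 64)                            # shape (c, d), with b == c

a_int, a_s = quantize_per_token(attn)  # a_s: (16, 1), scale per `a` index -> fine
v_int, v_s = quantize_per_token(v)     # v_s: (16, 1), scale per `c` index -> the problem

# Reference: dequantize, then matmul. v_s[k] sits inside the sum over k,
# so no post-hoc rescale of the pure integer matmul recovers it exactly.
reference = (a_int * a_s) @ (v_int * v_s)
int_matmul = a_int @ v_int
print(torch.allclose(int_matmul * a_s * v_s.mean(), reference))  # False in general

# Quantizing v along the `d` dimension (one scale per output column) fixes it:
v_s_col = v.abs().amax(dim=0, keepdim=True).clamp(min=1e-8) / 127.0  # (1, 64)
v_int_col = torch.round(v / v_s_col).clamp(-128, 127)
factored = (a_int @ v_int_col) * (a_s * v_s_col)  # outer product of scales
reference_col = (a_int * a_s) @ (v_int_col * v_s_col)
print(torch.allclose(factored, reference_col))    # True
```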