saic-fi / MobileQuant

[EMNLP Findings 2024] MobileQuant: Mobile-friendly Quantization for On-device Language Models

There is an error when using model.generate #9

Open npu-mi-ji opened 5 days ago

npu-mi-ji commented 5 days ago

In hf_model.py:

cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]

Assuming the number of input tokens is 48: in the prefill stage, cache_position = [0, 1, ..., 47] with shape (48,), and attention_mask has shape (1, 1, 48, 48). In the decoding stage, past_length = 48, cache_position = [48], and attention_mask has shape (1, 1, 1, 49), which leads to: IndexError: index 48 is out of bounds for dimension 0 with size 1.
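To make the mismatch concrete, here is a minimal sketch with toy tensors that reproduces the error; the shapes follow the numbers above, and this is not the repository's actual code:

```python
import torch

# Toy shapes from the report above; this is not the repo's actual code.

# Prefill: 48 query positions, the mask covers all of them -> indexing works.
prefill_mask = torch.zeros(1, 1, 48, 48)
cache_position = torch.arange(0, 48)
print(prefill_mask[:, :, cache_position, :48].shape)  # torch.Size([1, 1, 48, 48])

# Decode step: the mask has a single query row, but cache_position is [48].
decode_mask = torch.zeros(1, 1, 1, 49)
cache_position = torch.arange(48, 49)
try:
    decode_mask[:, :, cache_position, :49]
except IndexError as err:
    print(err)  # index 48 is out of bounds for dimension 0 with size 1
```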

Did I use it incorrectly?

It may need to be modified to:

if cache_position is not None and past_key_value is None:
    causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
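For illustration, a standalone sketch of that guard, wrapped in a hypothetical helper so the shapes can be checked in isolation; the names mirror hf_model.py, but the helper and its else branch are my own assumption about how the decode step could be handled, not code from the repo:

```python
from typing import Optional

import torch

# Hypothetical helper; the names mirror hf_model.py, but the function and its
# else branch are a sketch of the suggestion, not the repository's code.
def build_causal_mask(attention_mask: torch.Tensor,
                      cache_position: Optional[torch.Tensor],
                      key_len: int,
                      has_past: bool) -> torch.Tensor:
    if cache_position is not None and not has_past:
        # Prefill: the 4D mask covers every query position, so its rows can
        # be gathered with cache_position.
        return attention_mask[:, :, cache_position, :key_len]
    # Decode: the mask already has a single query row; only slice the keys.
    return attention_mask[:, :, :, :key_len]

# Toy check with the shapes from the report.
prefill = build_causal_mask(torch.zeros(1, 1, 48, 48), torch.arange(48), 48, False)
decode = build_causal_mask(torch.zeros(1, 1, 1, 49), torch.arange(48, 49), 49, True)
print(prefill.shape, decode.shape)  # (1, 1, 48, 48) and (1, 1, 1, 49)
```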

fwtan commented 5 hours ago


Hi,

model.generate has only been tested in https://github.com/saic-fi/MobileQuant/blob/8964f919ff3aa3fcc21236dff21aa31111891a5a/mobilellm/utils/bench.py#L164

There, only the context length (i.e. past_length) and the generation length (i.e. position_ids.shape[-1]) are set; the cache and mask variables are None by default and are created automatically afterwards.

In your case, it looks like the max lengths for prefilling and decoding were different, i.e. 48 for prefilling and 49 for decoding, so the caches and masks had different shapes in the two stages.

In our demo, we use consistent max lengths for prefilling and decoding, e.g. 1024 for prefilling and 1023 (context length) + 1 for decoding.
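As a toy illustration of what consistent max lengths mean for the mask shapes, the sketch below builds the causal mask once for the full window, so indexing it with cache_position stays valid in both stages; it mirrors the idea only and is not the actual code in mobilellm/utils/bench.py:

```python
import torch

# Toy illustration of "consistent max lengths": the causal mask is built once
# for the full window (here 1024), so indexing it with cache_position is valid
# in both stages. This mirrors the idea, not the code in mobilellm/utils/bench.py.
max_len = 1024
full_mask = torch.tril(torch.ones(1, 1, max_len, max_len, dtype=torch.bool))

# Prefill over the full 1024-token window.
cache_position = torch.arange(0, max_len)
print(full_mask[:, :, cache_position, :max_len].shape)  # torch.Size([1, 1, 1024, 1024])

# One decode step after a 1023-token context (1023 + 1 = 1024 in total).
cache_position = torch.arange(max_len - 1, max_len)
print(full_mask[:, :, cache_position, :max_len].shape)  # torch.Size([1, 1, 1, 1024])
```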