unslothai / unsloth

Finetune Llama 3, Mistral, Phi & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

sliding_window shouldn't be applied when flash_attn not installed? #680

Open rossbm opened 3 weeks ago

rossbm commented 3 weeks ago

I've been finetuning unsloth/Phi-3-mini-4k-instruct-bnb-4bit with a T4, which doesn't support flash attention, so I don't have it installed.

During evaluation, I've been running into the following error:

File /anaconda/envs/text2text-tagger/lib/python3.11/site-packages/unsloth/models/llama.py:218, in LlamaAttention_fast_forward_inference(self, hidden_states, past_key_value, position_ids, do_prefill, attention_mask)
    216     A = torch.matmul(A, Vnn, out = Qn)
    217 else:
--> 218     A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
    219 pass
    220 A = A.transpose(1, 2)

RuntimeError: The expanded size of the tensor (2047) must match the existing size (2956) at non-singleton dimension 3.  Target sizes: [2, 32, 1, 2047].  Tensor sizes: [2, 1, 1, 2956]

The batch that is being evaluated at this point has 2955 tokens. However, unsloth/Phi-3-mini-4k-instruct-bnb-4bit should support sequence lengths of 4096 tokens, and I make certain to set max_seq_length to 4096 when initializing the model.

Looking through the model config for unsloth/Phi-3-mini-4k-instruct-bnb-4bit, I see "sliding_window": 2048, which is the only place a length of 2048 (or 2047) could be coming from.

In https://github.com/unslothai/unsloth/blob/933d9fe2cb2459f949ee2250e90a5b610d277eab/unsloth/models/llama.py#L189 we have: if sliding_window is not None and kv_seq_len > sliding_window:
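
To make the mismatch concrete, here is a minimal, self-contained sketch (not unsloth code; the head count and head_dim are assumptions matching Phi-3-mini) that reproduces the broadcast failure: the keys/values get sliced to the last sliding_window - 1 = 2047 positions, while the attention mask still spans all 2956 cached tokens.

import torch
from torch.nn.functional import scaled_dot_product_attention

bsz, n_heads, head_dim = 2, 32, 96          # Phi-3-mini-like shapes (assumption)
kv_seq_len, sliding_window = 2956, 2048     # sizes from the traceback / model config

Qn = torch.randn(bsz, n_heads, 1, head_dim)           # single new query token
Kn = torch.randn(bsz, n_heads, kv_seq_len, head_dim)  # full KV cache
Vn = torch.randn(bsz, n_heads, kv_seq_len, head_dim)
attention_mask = torch.zeros(bsz, 1, 1, kv_seq_len)   # mask covering all 2956 positions

# Slice K/V to the last sliding_window - 1 = 2047 positions, as llama.py does ...
Knn = Kn[:, :, 1 - sliding_window:, :]
Vnn = Vn[:, :, 1 - sliding_window:, :]

try:
    # ... but pass the unsliced mask: 2956 cannot broadcast against 2047.
    scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
except RuntimeError as e:
    print(e)  # the same kind of expanded-size error as in the traceback above
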

However, in https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py, there's a check whether flash_attn is installed and supports a sliding window:

# Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements
# if is_flash_attn_2_available():
_flash_supports_window_size = False
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
except ImportError as error:
    logger.warning(
        f"`flash-attention` package not found, consider installing for better performance: {error}."
    )
    if not _flash_supports_window_size:
        logger.warning(
            "Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
        )

before the sliding window is used:

use_sliding_windows = (
    _flash_supports_window_size
    and getattr(self.config, "sliding_window", None) is not None
    and kv_seq_len > self.config.sliding_window
)

Sure enough, when I set model.config.sliding_window = 10_000, I am able to successfully call model.generate() on the batch that was giving me the RuntimeError: The expanded size of the tensor (2047) ... error.
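
Concretely, the workaround looks roughly like this (a sketch of my setup; the prompt and generation arguments are placeholders):

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
    max_seq_length = 4096,
    load_in_4bit = True,
)
FastLanguageModel.for_inference(model)

# Workaround: raise sliding_window above max_seq_length so the KV-cache slicing branch never triggers.
model.config.sliding_window = 10_000

inputs = tokenizer(["<placeholder prompt>"], return_tensors = "pt").to(model.device)
out = model.generate(**inputs, max_new_tokens = 64)
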

So I think the fix is to extend the if sliding_window is not None and kv_seq_len > sliding_window: condition to also check that flash-attention is installed and supports window_size, similar to what Phi-3 does.
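
A hedged sketch of what that guard could look like (hypothetical helper names, not a patch against llama.py):

import importlib.util
import inspect

def flash_attn_supports_window_size() -> bool:
    # True only if flash-attn is importable and its flash_attn_func accepts window_size,
    # mirroring the _flash_supports_window_size check in modeling_phi3.py.
    if importlib.util.find_spec("flash_attn") is None:
        return False
    from flash_attn import flash_attn_func
    return "window_size" in inspect.signature(flash_attn_func).parameters

def should_slice_kv(sliding_window, kv_seq_len) -> bool:
    # Hypothetical guard: only apply the sliding-window slicing when it is safe to do so.
    return (
        flash_attn_supports_window_size()
        and sliding_window is not None
        and kv_seq_len > sliding_window
    )
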

rossbm commented 3 weeks ago

I've tried running on another VM where I've installed flash_attn, but I'm still getting the error. Maybe the issue is that the token slicing isn't being applied to the attention mask.

From https://github.com/unslothai/unsloth/blob/main/unsloth/models/llama.py

if sliding_window is not None and kv_seq_len > sliding_window:
    # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
    slicing_tokens = 1 - sliding_window
    Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
    Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()

While in https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py and https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py we have:

if attention_mask is not None:
    attention_mask = attention_mask[:, slicing_tokens:]
    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
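
Putting the two pieces together, a hedged sketch of what the mask handling could look like (hypothetical helper; attention_mask here is the 2D padding mask, as in the transformers code above):

import torch

def slice_kv_and_mask(Kn, Vn, attention_mask, sliding_window):
    # Keep only the last sliding_window - 1 cached positions and shrink the padding
    # mask to match, appending a ones column for the token about to be generated
    # (mirrors modeling_mistral.py / modeling_phi3.py).
    slicing_tokens = 1 - sliding_window
    Knn = Kn[:, :, slicing_tokens:, :]
    Vnn = Vn[:, :, slicing_tokens:, :]
    if attention_mask is not None:
        attention_mask = attention_mask[:, slicing_tokens:]
        attention_mask = torch.cat(
            [attention_mask, torch.ones_like(attention_mask[:, -1:])], dim = -1
        )
    return Knn, Vnn, attention_mask
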

danielhanchen commented 2 weeks ago

@rossbm Apologies for the delay - my bro and I just relocated to SF. Appreciate the investigation as well!

I shall check whether I'm doing inference with sliding window attention (SWA) correctly :) Thanks for the report!