turboderp / exllama

A more memory-efficient rewrite of the HF transformers implementation of Llama for use with quantized weights.

Is it possible to do batch generate? #252

fahadh4ilyas opened this issue 10 months ago

fahadh4ilyas commented 10 months ago

So, I'm trying to do batch generation using the code oobabooga wrote in text-generation-webui, by calling the generate method of ExllamaHF. But an error was thrown. I guess it's because ExLlama only recognizes the first sequence of the batch, and that raises an error when GenerationMixin tries to concatenate the past sequence with the next token value.

Is there any way to do batch generation using exllama?

turboderp commented 10 months ago

ExLlama can do batch generation, but ExLlamaHF was written by ooba as a wrapper specifically for TGW. I don't know how batches are normally handled in TGW, so I can't really say why they're not being (as it seems) correctly forwarded to ExLlama.

fahadh4ilyas commented 10 months ago

Could you give an example input for batch generation? I recall that we need input_ids padded on the left side, attention_mask, and position_ids set. I guess attention_mask in Transformers corresponds to input_mask in exllama. But what about position_ids?

turboderp commented 10 months ago

There's an example in example_batch.py. It just calls generate_simple with a list of input strings rather than a single string, and then it returns a list of outputs instead of a single output.

If you want to call the forward pass directly, you need to specify a mask to go along with the batch, yes. You would call tokenizer.encode on a list of strings with return_mask = True, and it will return an int tensor with a batch of right-aligned token IDs, along with a boolean mask tensor of the same shape (False to indicate padding tokens). Then send those two to the forward pass and the model will output a batch of logits, pretty much like a HF model.

One thing to note is that, in either case, the cache needs to accommodate the batch size you intend to use.
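
Roughly, both routes look like this (a minimal sketch along the lines of example_batch.py; the paths and prompts are placeholders, and sampling settings are left at their defaults):

from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator

prompts = ["Hello, my name is", "Once upon a time"]

config = ExLlamaConfig("/path/to/model/config.json")        # placeholder paths
config.model_path = "/path/to/model/model.safetensors"
model = ExLlama(config)
tokenizer = ExLlamaTokenizer("/path/to/model/tokenizer.model")

# The cache has to be created for the batch size you intend to use.
cache = ExLlamaCache(model, batch_size = len(prompts))
generator = ExLlamaGenerator(model, tokenizer, cache)

# generate_simple accepts a list of strings and returns a list of outputs.
outputs = generator.generate_simple(prompts, max_new_tokens = 50)

# Or call the forward pass directly: encode returns right-aligned token IDs plus
# a boolean mask (False marks padding) when return_mask = True.
ids, mask = tokenizer.encode(prompts, return_mask = True)
cache2 = ExLlamaCache(model, batch_size = ids.shape[0])
logits = model.forward(ids, cache2, input_mask = mask)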

fahadh4ilyas commented 10 months ago

But how does it handle the position within the sequence? For example, because of the left padding, the sin and cos values must not start from the beginning, since the model will not see the BOS token as the first element of the sequence. The HF model handles this with position_ids. What about exllama? It seems that exllama just multiplies the sin and cos values into the query and key states without taking the sequence position into account.

turboderp commented 10 months ago

Well, rotary position embeddings are supposed to be position-independent, i.e. they affect attention between positions m and n in a way that depends only on the difference, n - m. So padding the sequence on the left should have no effect as long as the padding tokens are masked out.

That's the theory, at least. It does seem to produce correct output, too. Whether it still holds in extreme cases, with rounding errors etc., is hard to say, but I haven't had any issues in testing.
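
As a sketch of the standard argument (with $R_m$ the RoPE rotation for position $m$), the attention score between a query at position $m$ and a key at position $n$ is

$$ (R_m q)^\top (R_n k) = q^\top R_m^\top R_n\, k = q^\top R_{n-m}\, k, $$

since the rotations satisfy $R_m^\top R_n = R_{n-m}$. Left padding shifts $m$ and $n$ by the same offset, so $n - m$ is unchanged, as long as the padding positions themselves are masked out of the attention.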

fahadh4ilyas commented 10 months ago

Okay, I'm going to test it. Say there are two sequences that we want to batch-generate together, but the first sequence is extremely long and the second one is short. I'm afraid the result generated for the second sequence will be affected, since the position of its first token is close to the maximum sequence length.

fahadh4ilyas commented 10 months ago

@turboderp This is the script for my test:

import torch, os, sys
from pathlib import Path
from typing import Optional, List, Union
from transformers import AutoTokenizer
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.generation.utils import GenerationMixin, GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

BASE_DIR = Path(os.path.abspath('.'))

class RelativeImport:
    def __init__(self, path):
        self.import_path = BASE_DIR / Path(path)

    def __enter__(self):
        sys.path.insert(0, str(self.import_path))

    def __exit__(self, exc_type, exc_value, traceback):
        sys.path.remove(str(self.import_path))

with RelativeImport('exllama'):
    from model import ExLlama, ExLlamaCache, ExLlamaConfig
    from lora import ExLlamaLora

class ExLlamaForCausalLM(GenerationMixin):

    def __init__(
        self,
        config: LlamaConfig,
        generation_config: GenerationConfig,
        exllama_config: ExLlamaConfig,
        model: ExLlama,
        lora: ExLlamaLora,
        **kwargs
    ):
        self.config = config
        self.generation_config = generation_config
        self.exllama_config = exllama_config
        self.model = model
        self.lora = lora

        self.main_input_name = 'input_ids'

    def can_generate(self):
        return True

    @property
    def device(self) -> torch.device:
        return torch.device(0)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {'input_ids': input_ids, **kwargs}

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is None:
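            # Generation path (no labels): prefill the cache with every token except
            # the last one, then run the forward pass on just the last token to get
            # the logits for the next-token prediction.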
            if past_key_values is None:
                past_key_values = ExLlamaCache(self.model, input_ids.shape[0])
                self.model.forward(input_ids[...,:-1], past_key_values, preprocess_only=True, lora=self.lora, input_mask=attention_mask[...,:-1].to(torch.bool))

            logits = self.model.forward(input_ids[...,-1:], past_key_values, lora=self.lora, input_mask=attention_mask[...,-1:].to(torch.bool)).to(input_ids.device)
        else:
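            # Scoring path (labels provided): run a full forward pass over the whole
            # sequence so the logits can be used for the loss computation below.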
            if past_key_values is None:
                past_key_values = ExLlamaCache(self.model, input_ids.shape[0])

            # last_id_only=False so logits are returned for every position, not just the last
            logits = self.model.forward(input_ids, past_key_values, last_id_only=False, lora=self.lora, input_mask=attention_mask.to(torch.bool))

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits, past_key_values if use_cache else None)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values if use_cache else None, loss=loss)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        gpu_split: Optional[str] = None,
        lora_path: Optional[Union[str, os.PathLike]] = None
    ):
        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)

        if isinstance(lora_path, str):
            lora_path = Path(lora_path)

        config = LlamaConfig.from_pretrained(pretrained_model_name_or_path)

        generation_config = GenerationConfig.from_pretrained(pretrained_model_name_or_path)

        exllama_config = ExLlamaConfig(pretrained_model_name_or_path / 'config.json')
        exllama_config.max_seq_len = config.max_position_embeddings
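        # Map HF rope_scaling settings onto exllama's equivalents:
        # linear scaling -> compress_pos_emb, dynamic NTK scaling -> alpha_value.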
        if config.rope_scaling is not None:
            if config.rope_scaling['type'] == 'linear':
                exllama_config.compress_pos_emb = config.rope_scaling['factor']
            elif config.rope_scaling['type'] == 'dynamic':
                exllama_config.alpha_value = config.rope_scaling['factor']
                exllama_config.calculate_rotary_embedding_base()
        if gpu_split is not None:
            exllama_config.set_auto_map(gpu_split.replace(' ', ','))
            exllama_config.gpu_peer_fix = True

        if torch.version.hip:
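            # ROCm/HIP builds: fall back to the kernels that avoid half2 intrinsics.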
            exllama_config.rmsnorm_no_half2 = True
            exllama_config.rope_no_half2 = True
            exllama_config.matmul_no_half2 = True
            exllama_config.silu_no_half2 = True

        weight_path = None
        for ext in ['.safetensors', '.pt', '.bin']:
            found = list(pretrained_model_name_or_path.glob(f"*{ext}"))
            if len(found) > 0:
                weight_path = found[-1]
                break
        assert weight_path is not None, f'could not find weight in "{pretrained_model_name_or_path}"'

        exllama_config.model_path = str(weight_path)

        model = ExLlama(exllama_config)

        lora_model = None
        if lora_path is not None:
            lora_model = ExLlamaLora(model, str(lora_path / "adapter_config.json"), str(lora_path / "adapter_model.bin"))

        return cls(config, generation_config, exllama_config, model, lora_model)

def main(model_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)

    tokenizer.padding_side = 'left'

    data = tokenizer(['My name is', 'Nama saya adalah'], padding=True, return_tensors='pt')

    model = ExLlamaForCausalLM.from_pretrained(model_path)

    generate_result = model.generate(**data, max_new_tokens=100)

    print('Batch result:', tokenizer.decode(generate_result[0]))

    data = tokenizer(['My name is'], padding=True, return_tensors='pt')

    generate_result = model.generate(**data, max_new_tokens=100)

    print('Non batch result:', tokenizer.decode(generate_result[0]))

if __name__ == '__main__':
    model_path = input('Your model path: ')
    main(model_path)

The difference between the batch-generated result and the single-sequence result is large, even though I'm using greedy search. One possible reason is that the positions are not exact for the padded sequence.
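
One way I could try to narrow this down (a hypothetical check using the wrapper above, not part of the script I ran): compare the last-token logits of the short prompt when it is run alone versus when it is left-padded inside the batch. If the positions were handled consistently, the greedy argmax should match in both cases.

import torch

def compare_first_step(model, tokenizer, short_prompt, long_prompt):
    # Hypothetical helper: "model" is the ExLlamaForCausalLM instance from above,
    # and tokenizer.padding_side is assumed to be 'left' as in main().
    single = tokenizer([short_prompt], return_tensors='pt')
    batch = tokenizer([short_prompt, long_prompt], padding=True, return_tensors='pt')

    with torch.no_grad():
        single_out = model(input_ids=single['input_ids'],
                           attention_mask=single['attention_mask'])
        batch_out = model(input_ids=batch['input_ids'],
                          attention_mask=batch['attention_mask'])

    # Row 0 of the batch is the (left-padded) short prompt.
    single_next = single_out.logits[0, -1].argmax().item()
    batch_next = batch_out.logits[0, -1].argmax().item()
    print('single argmax:', single_next, 'batched argmax:', batch_next)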