marella / ctransformers

Python bindings for the Transformer models implemented in C/C++ using GGML library.
MIT License
1.79k stars 135 forks source link

Integrating with outlines #91

Open harryjulian opened 1 year ago

harryjulian commented 1 year ago

I'd love to use ctransformers with the outlines library for constrained generation. I opened this issue about it on their repo.

To hack on this integration, the forward pass and raw logit output of the model would need to be exposed in the ctransformers API. Having looked through ctransformers/llm.py, it seems that the .sample method, which wraps the .ctransformers_llm_sample method from the C API, is the only way to sample tokens from ctransformers models in its current state.

Would there be any scope for adding a .forward or .sample_logits method? How would I go about implementing it if there were?

marella commented 1 year ago

You can get the logits using the llm.logits property. To implement forward, you can use the low-level llm.eval() method. I have done some work on this in the past to make ctransformers a drop-in replacement for 🤗 transformers models (see https://github.com/marella/ctransformers/issues/13#issuecomment-1597662836). Here is the code for reference. Recently I created a better version of it but haven't pushed it yet. Here is sample code from the newer version:

import torch
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

class Model(PreTrainedModel):
    """Expose a ctransformers LLM through the 🤗 ``PreTrainedModel`` interface.

    ``forward`` feeds token ids to the wrapped LLM with ``llm.eval`` and
    returns the logits of the last evaluated position. Only a single
    sequence (batch size 1) is supported.
    """

    def __init__(self, config: PretrainedConfig, llm):
        # Copy token metadata from the ctransformers model into the (mutated)
        # HF config; EOS doubles as the padding token.
        config.vocab_size = llm.vocab_size
        config.eos_token_id = llm.eos_token_id
        config.pad_token_id = llm.eos_token_id
        super().__init__(config)
        self._llm = llm
        # Token ids most recently fed to ``llm``; lets ``forward`` evaluate
        # only the new suffix of a growing sequence.
        self._past = []

    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,
        **kwargs,
    ):
        """Pass the full sequence to ``forward``; other inputs are dropped."""
        return {"input_ids": input_ids}

    def forward(
        self,
        input_ids=None,
        return_dict=None,
        **kwargs,
    ):
        """Evaluate ``input_ids`` with the wrapped LLM and return its logits.

        If ``input_ids`` starts with the sequence seen on the previous call,
        only the new suffix is evaluated. Otherwise the LLM is reset and the
        whole sequence is evaluated again.

        Args:
            input_ids: token-id tensor for one sequence, either 1-D or with a
                leading batch dimension of size 1.
            return_dict: if truthy, return a ``CausalLMOutput``; otherwise a
                one-element tuple.
            **kwargs: accepted for compatibility and ignored.

        Returns:
            ``llm.logits`` reshaped to ``(1, 1, -1)`` (last position only).

        Raises:
            ValueError: if ``input_ids`` is missing or empty, or holds more
                than one sequence.
        """
        if input_ids is None or input_ids.numel() == 0:
            raise ValueError("input_ids must contain at least one token")
        if input_ids.dim() > 1 and input_ids.shape[0] != 1:
            # flatten() below would splice all rows into one bogus sequence.
            raise ValueError(f"batch size must be 1, got {input_ids.shape[0]}")
        llm = self._llm
        tokens = input_ids.flatten().tolist()
        n_past = len(self._past)
        if tokens[:n_past] == self._past:
            # Same prefix as last time: only the unseen suffix is new.
            new_tokens = tokens[n_past:]
        else:
            # Sequence diverged: discard the LLM state and start over.
            llm.reset()
            new_tokens = tokens
        self._past = tokens
        if new_tokens:
            # An identical repeat adds nothing; the previous logits still apply.
            llm.eval(new_tokens)
        logits = torch.tensor(llm.logits).reshape([1, 1, -1])
        if not return_dict:
            return (logits,)
        return CausalLMOutput(logits=logits)

    @property
    def device(self) -> torch.device:
        """Always report the CPU; this wrapper registers no torch parameters."""
        return torch.device("cpu")

It can be used as:

from ctransformers import AutoModelForCausalLM

# Load a model with ctransformers (arguments elided here) and wrap it in the
# 🤗-compatible ``Model`` class using a fresh, default config.
llm = AutoModelForCausalLM.from_pretrained(...)
hf_config = PretrainedConfig()
model = Model(hf_config, llm)

But the API is not finalized and may change by the time I release it.

TingTingin commented 1 year ago

I'm trying to use ctransformers with outlines. I used this code:

from ctransformers import AutoModelForCausalLM
from ctransformers.llm import  LLM, get
import torch
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
import outlines.text.generate as generate
import outlines.models as models
from typing import List, Optional, Union

from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.generation import SampleDecoderOnlyOutput
from transformers.generation.streamers import BaseStreamer

class Model(PreTrainedModel):
    """Expose a ctransformers LLM through the 🤗 ``PreTrainedModel`` interface.

    ``forward`` feeds token ids to the wrapped LLM with ``llm.eval`` and
    returns the logits of the last evaluated position. Only a single
    sequence (batch size 1) is supported.
    """

    def __init__(self, config: PretrainedConfig, llm):
        # Copy token metadata from the ctransformers model into the (mutated)
        # HF config; EOS doubles as the padding token.
        config.vocab_size = llm.vocab_size
        config.eos_token_id = llm.eos_token_id
        config.pad_token_id = llm.eos_token_id
        super().__init__(config)
        self._llm = llm
        # Token ids most recently fed to ``llm``; lets ``forward`` evaluate
        # only the new suffix of a growing sequence.
        self._past = []

    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,
        **kwargs,
    ):
        """Pass the full sequence to ``forward``; other inputs are dropped."""
        return {"input_ids": input_ids}

    def forward(
        self,
        input_ids=None,
        return_dict=None,
        **kwargs,
    ):
        """Evaluate ``input_ids`` with the wrapped LLM and return its logits.

        If ``input_ids`` starts with the sequence seen on the previous call,
        only the new suffix is evaluated. Otherwise the LLM is reset and the
        whole sequence is evaluated again.

        Args:
            input_ids: token-id tensor for one sequence, either 1-D or with a
                leading batch dimension of size 1.
            return_dict: if truthy, return a ``CausalLMOutput``; otherwise a
                one-element tuple.
            **kwargs: accepted for compatibility and ignored.

        Returns:
            ``llm.logits`` reshaped to ``(1, 1, -1)`` (last position only).

        Raises:
            ValueError: if ``input_ids`` is missing or empty, or holds more
                than one sequence.
        """
        if input_ids is None or input_ids.numel() == 0:
            raise ValueError("input_ids must contain at least one token")
        if input_ids.dim() > 1 and input_ids.shape[0] != 1:
            # flatten() below would splice all rows into one bogus sequence.
            raise ValueError(f"batch size must be 1, got {input_ids.shape[0]}")
        llm = self._llm
        tokens = input_ids.flatten().tolist()
        n_past = len(self._past)
        if tokens[:n_past] == self._past:
            # Same prefix as last time: only the unseen suffix is new.
            new_tokens = tokens[n_past:]
        else:
            # Sequence diverged: discard the LLM state and start over.
            llm.reset()
            new_tokens = tokens
        self._past = tokens
        if new_tokens:
            # An identical repeat adds nothing; the previous logits still apply.
            llm.eval(new_tokens)
        logits = torch.tensor(llm.logits).reshape([1, 1, -1])
        if not return_dict:
            return (logits,)
        return CausalLMOutput(logits=logits)

    @property
    def device(self) -> torch.device:
        """Always report the CPU; this wrapper registers no torch parameters."""
        return torch.device("cpu")

# Load the quantized GGML model via ctransformers, offloading 30 layers to GPU.
llm = AutoModelForCausalLM.from_pretrained(r"openorca-platypus2-13b.ggmlv3.q4_K_S.bin",
                                           model_type='llama',
                                           gpu_layers=30,
                                           context_length=4096)
model = Model(PretrainedConfig(), llm)
# Text to classify; must be defined before it is passed to the generator.
prompt = """User: You are a sentiment-labeling assistant label this review 
Review: This restaurant was very bad! <|end_of_turn|>
Assistant: 
"""
answer = generate.choice(model, ["Positive", "Negative"])(prompt)

I tried that and got this error: AttributeError: 'Model' object has no attribute 'tokenizer'

So I added this:

class Tokenizer:
    """Minimal tokenizer facade over a ctransformers ``LLM``.

    Provides a small subset of the 🤗 tokenizer interface (encode/decode,
    id <-> token conversion, EOS/pad ids) by delegating to ``llm.tokenize``
    and ``llm.detokenize``.
    """

    def __init__(self, llm: "LLM") -> None:
        self._llm = llm
        self.vocab_size = llm.vocab_size
        self.eos_token_id = llm.eos_token_id
        # No dedicated pad token is available; reuse EOS, the same choice the
        # ``Model`` wrapper makes for its config.
        self.pad_token_id = self.eos_token_id
        # Fall back to "</s>" when the EOS id detokenizes to an empty string.
        # TODO: confirm the right EOS string for non-llama models.
        self.eos_token = llm.detokenize(self.eos_token_id) or "</s>"
        self.max_sequence_length = llm.context_length
        # NOTE(review): a ``vocabulary`` mapping (token string -> id) is still
        # missing; no ctransformers API providing it is visible here.

    def encode(self, text: str) -> List[int]:
        """Tokenize ``text`` with the wrapped LLM."""
        return self._llm.tokenize(text)

    def decode(
        self,
        token_ids: Union[int, List[int], torch.Tensor],
    ) -> str:
        """Convert token id(s) back to text; tensors are turned into lists first."""
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()
        return self._llm.detokenize(token_ids)

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]]
    ) -> Union[str, List[str]]:
        """Decode each id on its own; a single int yields a single string."""
        if isinstance(ids, int):
            return self.decode(ids)
        return [self.decode(token_id) for token_id in ids]

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Concatenate token strings with no separator."""
        return "".join(tokens)

    def convert_tokens_to_ids(
        self, tokens: Optional[Union[str, List[str]]]
    ) -> Optional[Union[int, List[int]]]:
        """Map token string(s) to id(s) by encoding each one.

        For ``llama`` models the id at index 1 is taken, skipping the first id
        ``encode`` returns (presumably a prepended BOS — TODO confirm);
        otherwise index 0. Only that single id is kept per token string.
        ``None`` is passed through unchanged.
        """
        if tokens is None:
            return None
        index = 1 if self._llm.model_type == "llama" else 0
        if isinstance(tokens, str):
            return self.encode(tokens)[index]
        return [self.encode(token)[index] for token in tokens]

# Attach a Tokenizer built from the same ctransformers LLM; the bare name
# ``tokenizer`` was never defined in this snippet.
model.tokenizer = Tokenizer(llm)

And got this error: AttributeError: 'Tokenizer' object has no attribute 'pad_token_id'. Did you mean: 'eos_token_id'?

So I added this:

# Reuse EOS as the padding token, matching what Model.__init__ sets on the config.
tok = model.tokenizer
tok.pad_token_id = tok.eos_token_id

And got this error: AttributeError: 'Tokenizer' object has no attribute 'vocabulary'

However, I can't seem to find a way to get the vocabulary of an LLM; the only related attribute I found was vocab_size. I'm not sure how to proceed.

marella commented 1 year ago

It looks like outlines' generate needs a custom model and tokenizer object, which can be created using the Transformers and TransformersTokenizer classes. I recommend using the original HF tokenizer to simplify things:

# 🤗-compatible wrapper around the ctransformers LLM.
hf_model = Model(PretrainedConfig(), llm)

from outlines.models.transformers import Transformers, TransformersTokenizer

# Original Hugging Face tokenizer; change the repo id to match the model
# loaded with ctransformers.
tokenizer = TransformersTokenizer("Open-Orca/OpenOrca-Platypus2-13B")
model = Transformers(model=hf_model, tokenizer=tokenizer)  # outlines model

answer = generate.choice(model, ["Positive", "Negative"])(prompt)
TingTingin commented 1 year ago

Thanks, it seems to be working now. For anyone wondering, this is the full code:

from ctransformers import AutoModelForCausalLM
import torch
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from outlines.models.transformers import Transformers, TransformersTokenizer
import outlines.text.generate as generate

class Model(PreTrainedModel):
    """Expose a ctransformers LLM through the 🤗 ``PreTrainedModel`` interface.

    ``forward`` feeds token ids to the wrapped LLM with ``llm.eval`` and
    returns the logits of the last evaluated position. Only a single
    sequence (batch size 1) is supported.
    """

    def __init__(self, config: PretrainedConfig, llm):
        # Copy token metadata from the ctransformers model into the (mutated)
        # HF config; EOS doubles as the padding token.
        config.vocab_size = llm.vocab_size
        config.eos_token_id = llm.eos_token_id
        config.pad_token_id = llm.eos_token_id
        super().__init__(config)
        self._llm = llm
        # Token ids most recently fed to ``llm``; lets ``forward`` evaluate
        # only the new suffix of a growing sequence.
        self._past = []

    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,
        **kwargs,
    ):
        """Pass the full sequence to ``forward``; other inputs are dropped."""
        return {"input_ids": input_ids}

    def forward(
        self,
        input_ids=None,
        return_dict=None,
        **kwargs,
    ):
        """Evaluate ``input_ids`` with the wrapped LLM and return its logits.

        If ``input_ids`` starts with the sequence seen on the previous call,
        only the new suffix is evaluated. Otherwise the LLM is reset and the
        whole sequence is evaluated again.

        Args:
            input_ids: token-id tensor for one sequence, either 1-D or with a
                leading batch dimension of size 1.
            return_dict: if truthy, return a ``CausalLMOutput``; otherwise a
                one-element tuple.
            **kwargs: accepted for compatibility and ignored.

        Returns:
            ``llm.logits`` reshaped to ``(1, 1, -1)`` (last position only).

        Raises:
            ValueError: if ``input_ids`` is missing or empty, or holds more
                than one sequence.
        """
        if input_ids is None or input_ids.numel() == 0:
            raise ValueError("input_ids must contain at least one token")
        if input_ids.dim() > 1 and input_ids.shape[0] != 1:
            # flatten() below would splice all rows into one bogus sequence.
            raise ValueError(f"batch size must be 1, got {input_ids.shape[0]}")
        llm = self._llm
        tokens = input_ids.flatten().tolist()
        n_past = len(self._past)
        if tokens[:n_past] == self._past:
            # Same prefix as last time: only the unseen suffix is new.
            new_tokens = tokens[n_past:]
        else:
            # Sequence diverged: discard the LLM state and start over.
            llm.reset()
            new_tokens = tokens
        self._past = tokens
        if new_tokens:
            # An identical repeat adds nothing; the previous logits still apply.
            llm.eval(new_tokens)
        logits = torch.tensor(llm.logits).reshape([1, 1, -1])
        if not return_dict:
            return (logits,)
        return CausalLMOutput(logits=logits)

    @property
    def device(self) -> torch.device:
        """Always report the CPU; this wrapper registers no torch parameters."""
        return torch.device("cpu")

# Local GGML weights; point this at wherever the file was downloaded
# (a relative path keeps the example portable across machines).
MODEL_PATH = "openorca-platypus2-13b.ggmlv3.q4_K_S.bin"

llm = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    model_type="llama",
    gpu_layers=30,
    context_length=4096,
)
# Wrap the ctransformers LLM for 🤗, then hand it to outlines together with
# the original Hugging Face tokenizer for the same model.
hf_model = Model(PretrainedConfig(), llm)
tokenizer = TransformersTokenizer("Open-Orca/OpenOrca-Platypus2-13B")
model = Transformers(model=hf_model, tokenizer=tokenizer)

prompt = """User: You are a sentiment-labeling assistant label this review 
Review: This restaurant was very bad! <|end_of_turn|>
Assistant: 
"""

# Constrain generation to exactly one of the two labels.
answer = generate.choice(model, ["Positive", "Negative"])(prompt)
print(answer)
Ozennefr commented 10 months ago

For outlines 0.0.9, the forward pass should return a CausalLMOutputWithPast object instead of a CausalLMOutput. A drop-in replacement works, although further tweaks should allow for significant speedups.