turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs

Lora doesn't impact the model outputs #169

Closed: matankley closed this issue 8 months ago

matankley commented 9 months ago

@turboderp

I'm running the LoRA example, but I'm getting the same results for inference with the LoRA adapter as for inference without it.

Here is the code I'm running:


STREAMING=False
MODEL_DIR = "/workspace/cache/Mistral-7B-instruct-exl2/"
LORA_PATH = "<some_local_path>"

import sys, os
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
    ExLlamaV2Lora,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler
)

import time

# Initialize model and cache

config = ExLlamaV2Config()
config.model_dir = MODEL_DIR
config.prepare()

model = ExLlamaV2(config)
print("Loading model: " + MODEL_DIR)
model.load()

tokenizer = ExLlamaV2Tokenizer(config)

cache = ExLlamaV2Cache(model)

# Load LoRA

lora = ExLlamaV2Lora.from_directory(model, LORA_PATH)

# Initialize generators

streaming_generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
streaming_generator.warmup()
streaming_generator.set_stop_conditions([tokenizer.eos_token_id])

simple_generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

# Sampling settings

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 1
settings.top_k = 1
settings.top_p = 1
settings.token_repetition_penalty = 1.1
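# Note: with top_k = 1, sampling is effectively greedy, so outputs are deterministic and easy to compare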

# Generate with and without LoRA

def generate_with_lora(prompt_, lora_, max_new_tokens, streaming_ = True):

    print(prompt_, end="")
    sys.stdout.flush()

    if streaming_:

        input_ids = tokenizer.encode(prompt_)

        streaming_generator.begin_stream(input_ids, settings, loras = lora_)
        generated_tokens = 0
        while True:
            chunk, eos, _ = streaming_generator.stream()
            generated_tokens += 1
            print (chunk, end = "")
            sys.stdout.flush()
            if eos or generated_tokens == max_new_tokens: break

        print()

    else:

        output = simple_generator.generate_simple(prompt_, settings, max_new_tokens, loras = lora_)

        print (output[len(prompt_):])
        print()

Can you please assist with that?

turboderp commented 9 months ago

It's hard to say what's happening. When I run the example as-is, with and without an Alpaca LoRA over the original Llama-7B, this is the output:

--------------------------
No LoRA:

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Write three tweets explaining that the Earth is not flat, using spaghetti and meatballs as an analogy.

### Response:

1. "Earth is round like a spaghetti ball." (Tweet)
2. "No matter how you look at it, the spaghetti ball is round." (Tweet)
3. "You can put this on your plate or eat it straight from the jar." (Tweet)

## External links

* Twitter
* SAT
* SAT 2016: A Revised Scoring Plan

--------------------------
Yes LoRA:

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Write three tweets explaining that the Earth is not flat, using spaghetti and meatballs as an analogy.

### Response:
1. Spaghetti and Meatballs are not flat. Just like the Earth, they are round.
2. The Earth is round too, just like the spaghetti and meatballs. 
3. If you want to learn more about why the Earth is round, try looking up some science articles on Google Scholar.

It's very consistently sticking to the format imposed by the adapter. Without the LoRA it's trying its best to interpret the ### tags and improvising as you'd expect a base model to do.

So it's hard to say what's happening in your case. Mistral-instruct is already very heavily tuned and may have strong preferences that whatever LoRA you're using can't overcome. Do you have a more complete example including the calls to generate_with_lora?
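
For reference, the kind of calls I'd expect look roughly like this (just a sketch; the prompt string and token budget are placeholders, not taken from your code):

# Hypothetical A/B comparison using the generate_with_lora function above
prompt = "<your test prompt>"   # placeholder

print("--- No LoRA ---")
generate_with_lora(prompt, None, max_new_tokens = 200, streaming_ = False)

print("--- Yes LoRA ---")
generate_with_lora(prompt, lora, max_new_tokens = 200, streaming_ = False)

Since top_k = 1 makes generation effectively greedy, identical output from those two calls would mean the adapter genuinely isn't changing the logits, rather than it being sampling noise.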