Vahe1994 / AQLM

Official Pytorch repository for Extreme Compression of Large Language Models via Additive Quantization https://arxiv.org/pdf/2401.06118.pdf and PV-Tuning: Beyond Straight-Through Estimation for Extreme LLM Compression https://arxiv.org/abs/2405.14852
Apache License 2.0

How to import and use it in my existent code that loads LLMs? #109

Closed Kuchiriel closed 2 months ago

Kuchiriel commented 4 months ago

This is the code; I achieved 4-bit quantization with the standard libraries.

import gc
import os
import re

import torch
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import nltk
import fitz  # PyMuPDF
import numpy as np
import scipy
import tiktoken
from sentence_transformers import SentenceTransformer
from scipy.stats import fisher_exact
from torch import nn
from transformers import (
    GenerationConfig,
    TextStreamer,
    BitsAndBytesConfig,
    GPT2TokenizerFast,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from IPython.display import display
import ipywidgets as widgets
from accelerate import Accelerator, init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
from nltk.corpus import stopwords

# Download necessary NLTK data
nltk.download("stopwords")

# Display library versions
print(f"NumPy version: {np.__version__}")
print(f"SciPy version: {scipy.__version__}")

# Perform a Fisher's exact test as an example
result = fisher_exact([[10, 10], [5, 20]])
print("Fisher's exact test result:", result)

# Model and embedding configurations
CONTENT = "You are a friendly chatbot who always responds to user instructions precisely and remembers the previous prompts"
conversation_history = []
MAX_TOKENS_PER_MB = 20

model_names = [
    "TheBloke/zephyr-7B-beta-GPTQ",
    "TheBloke/zephyr-7B-beta-GGUF",
    "HuggingFaceH4/zephyr-7b-beta",
]
CPU_MODEL = "MaziyarPanahi/Phi-3-mini-4k-instruct-v0.3"
EMBEDDING_MODEL_NAME = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Open and read PDF
doc = fitz.open("/kaggle/input/attention-is-all-you-need/attention_is_all_you_need.pdf")
text = ""
for page in doc:
    text += page.get_text()

# Tokenize the text
enc = tiktoken.encoding_for_model("gpt2")

def count_tokens(text: str) -> int:
    return len(enc.encode(text))

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    chunk_overlap=24,
    length_function=count_tokens,
)

chunks = text_splitter.create_documents([text])

# Plot token counts
token_counts = [count_tokens(chunk.page_content) for chunk in chunks]
df = pd.DataFrame({"Token Count": token_counts})
df.hist(bins=40)
plt.show()

# Sentence transformer wrapper class
class SentenceTransformerWrapper:
    def __init__(self, model_name):
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        return self.model.encode(texts)

    def embed_query(self, text):
        return self.model.encode([text])[0]

    def __call__(self, text):
        return self.embed_query(text)

embedding_model = SentenceTransformerWrapper(EMBEDDING_MODEL_NAME)

# Create FAISS database
db = FAISS.from_documents(chunks, embedding_model)

# Initialize accelerator
accelerator = Accelerator()

def get_available_memory():
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
        total_memory = torch.cuda.get_device_properties(0).total_memory
        available_memory = total_memory - torch.cuda.memory_allocated()
        return available_memory
    return None

def set_dynamic_memory_allocation(max_size_mb=None):
    total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**2)
    print("Total memory:", total_memory)
    available_memory = total_memory

    if max_size_mb is None:
        max_size_mb = min(128, int(available_memory / 4))

    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:{max_size_mb}"
    print(f"Setting max_split_size_mb to {max_size_mb} MB")

def test_is_tpu_available():
    devices = tf.config.list_logical_devices()
    for device in devices:
        if device.device_type == "TPU":
            return True
    return False

# Check for available devices
if torch.cuda.is_available():
    DEVICE = "cuda"
    print("GPU Available:", torch.cuda.get_device_name(torch.cuda.current_device()))
elif test_is_tpu_available():
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        print(f"Running on TPU: {tpu.cluster_spec().as_dict()['worker']}")
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
        DEVICE = "tpu"
    except tf.errors.AlreadyExistsError:
        print("TPU already initialized")
    except tf.errors.FailedPreconditionError as e:
        print(f"Failed to initialize TPU: {e}")
        DEVICE = "cpu"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

# Set data types based on device
if DEVICE == "cuda" and torch.cuda.is_bf16_supported():
    DTYPE = torch.bfloat16
elif DEVICE == "cuda" and not torch.cuda.is_bf16_supported():
    DTYPE = torch.float16
else:
    DTYPE = torch.float32
print("DTYPE:", DTYPE)

def flush():
    if DEVICE in ["cuda", "tpu"]:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    accelerator.free_memory()
    gc.collect()

def clear_all():
    locals_to_remove = [var for var in locals() if var[0] != "_"]
    for var in locals_to_remove:
        del locals()[var]
    globals_to_remove = [var for var in globals() if var[0] != "_"]
    for var in globals_to_remove:
        del globals()[var]
    flush()

def print_memory_usage():
    if DEVICE == "cuda":
        allocated = torch.cuda.memory_allocated() / 1e9
        max_allocated = torch.cuda.max_memory_allocated() / 1e9
        print(f"Memory Allocated: {allocated} GB, Max Allocated: {max_allocated} GB")
    flush()

def inspect_model_bits(model):
    for name, param in model.named_parameters():
        print(f"Parameter: {name}, dtype: {param.dtype}")

def load_gptq_model(model_name):
    from auto_gptq import AutoGPTQForCausalLM
    config = BitsAndBytesConfig(
        load_in_4bit=True,
        disable_exllama=True,
    )
    model = AutoGPTQForCausalLM.from_pretrained(model_name, quantization_config=config)
    return model

def load_gguf_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(model_name, model_type="mistral")
    return model

def determine_model_type(model):
    try:
        if hasattr(model, 'gptq_config'):
            return "gptq"
        elif hasattr(model.config, 'quantization_approach') and model.config.quantization_approach == 'GGUF':
            return "gguf"
        else:
            return "unknown"
    except Exception as e:
        print(f"Error determining model type: {e}")
        return "unknown"

def load_model_by_type(model_type, model_name):
    try:
        if model_type == "gptq":
            return load_gptq_model(model_name)
        elif model_type == "gguf":
            return load_gguf_model(model_name)
        else:
            raise ValueError(f"Unknown model type: {model_type}")
    except Exception as e:
        print(f"Failed to load {model_name}: {e}")
        return None

def setup_model_and_tokenizer(model_name):
    torch.set_grad_enabled(False)
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            use_fast=True,
            legacy=True,
            trust_remote_code=True if DEVICE == "cpu" else False,
        )
        model = AutoModelForCausalLM.from_pretrained(model_name)
        model_type = determine_model_type(model)
        print(f"Detected model type: {model_type}")
        if model_type == "unknown":
            raise ValueError(f"Unknown model type for {model_name}")
        model = load_model_by_type(model_type, model_name)
        if model is None:
            raise ValueError(f"Failed to load model {model_name}")

        tokenizer.save_pretrained("/kaggle/working")

        accelerator.save_model(model, save_directory="/kaggle/working", max_shard_size="GB")

        device_map = infer_auto_device_map(model, max_memory={0: "5GiB", 1: "5GiB", "cpu": "30GiB"})

        model = load_checkpoint_and_dispatch(
            model, model_name, device_map=device_map, no_split_module_classes=["GPTJBlock"]
        )

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=DTYPE,
        )

        model = accelerator.prepare(model)
        model.eval()

        if tokenizer.sep_token is None:
            tokenizer.sep_token = "[SEP]"

        if tokenizer.cls_token is None:
            tokenizer.cls_token = "[CLS]"

        if tokenizer.mask_token is None:
            tokenizer.mask_token = "[MASK]"

        model.config.update(
            {
                "load_in_4bit": True,
                "quantization_config": bnb_config,
                "device_map": "Auto",
                "attn_implementation": "flash_attention_2",
                "pad_token_id": tokenizer.pad_token_id,
                "eos_token_id": tokenizer.eos_token_id,
                "bos_token_id": tokenizer.bos_token_id,
                "unk_token_id": tokenizer.unk_token_id,
                "sep_token_id": tokenizer.sep_token_id,
                "cls_token_id": tokenizer.cls_token_id,
                "mask_token_id": tokenizer.mask_token_id,
            }
        )

        inspect_model_bits(model)

        if torch.cuda.device_count() > 1:
            print(f"{torch.cuda.device_count()} GPUs detected, initializing Data Parallel...")
            model = nn.DataParallel(model)
            model.to(DTYPE)

            underlying_model = model.module if isinstance(model, torch.nn.DataParallel) else model

            if underlying_model.config.vocab_size != len(tokenizer):
                underlying_model.resize_token_embeddings(len(tokenizer))
        elif torch.cuda.device_count() < 1:
            model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=DTYPE)

        generation_config = GenerationConfig(
            do_sample=True,
            temperature=0.9,
            min_length=50,
            max_length=300,  # Reduced max length
            bnb_4bit_compute_dtype=DTYPE,
            penalty_alpha=0.6,
            repetition_penalty=1.1,
            top_k=15,
            top_p=0.95,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            unk_token_id=tokenizer.unk_token_id,
            sep_token_id=tokenizer.sep_token_id,
            cls_token_id=tokenizer.cls_token_id,
            mask_token_id=tokenizer.mask_token_id,
        )

        return model, tokenizer, generation_config
    except Exception as e:
        print(f"Failed to setup model and tokenizer for {model_name}: {e}")
        return None, None, None

if DEVICE in ["cuda", "tpu"]:
    for model_name in model_names:
        print(f"\nTesting model: {model_name}")
        try:
            model, tokenizer, generation_config = setup_model_and_tokenizer(model_name)
            if model is not None:
                print(f"Successfully loaded {model_name}")
            else:
                print(f"Failed to load {model_name}")
        except Exception as e:
            print(f"Failed to load {model_name}: {e}")
else:
    MODEL = CPU_MODEL

def truncate_history_based_on_memory():
    available_memory = get_available_memory()
    if available_memory is not None:
        max_tokens = int(available_memory / (1024**2)) * MAX_TOKENS_PER_MB
        current_tokens = sum(
            len(tokenizer.encode(exchange["content"])) for exchange in conversation_history
        )

        while current_tokens > max_tokens:
            conversation_history.pop(0)
            current_tokens = sum(
                len(tokenizer.encode(exchange["content"]))
                for exchange in conversation_history
            )

def determine_max_new_tokens(input_length, max_model_input_size, prompt):
    complexity_factor = 1.0
    if is_complex(prompt):
        complexity_factor = 1.5

    if input_length > max_model_input_size // 2:
        return int((max_model_input_size // 4) * complexity_factor)
    else:
        return int((max_model_input_size // 2) * complexity_factor)

def is_complex(prompt):
    word_count_threshold = 12
    unique_word_threshold = 10
    long_word_threshold = 7
    complex_sentence_threshold = 2

    words = prompt.split()

    if len(words) > word_count_threshold:
        return True

    unique_words = set(words)
    if len(unique_words) > unique_word_threshold:
        return True

    long_words = [word for word in words if len(word) >= long_word_threshold]
    if len(long_words) > long_word_threshold:
        return True

    sentences = re.split(r"[.!?]+", prompt)
    complex_sentences = [
        sentence
        for sentence in sentences
        if len(sentence.split()) > word_count_threshold
    ]
    if len(complex_sentences) > complex_sentence_threshold:
        return True

    stop_words = set(stopwords.words("english"))
    non_stop_words = [word for word in unique_words if word.lower() not in stop_words]
    if len(non_stop_words) > unique_word_threshold:
        return True

    return False

def estimate_memory_per_token(model):
    return 5

def calculate_dynamic_max_length(model, buffer_factor=0.8):
    memory_per_token = estimate_memory_per_token(model)
    available_memory = get_available_memory()

    if available_memory is None:
        return 512

    dynamic_max_length = int((available_memory * buffer_factor) / memory_per_token)
    return dynamic_max_length

def segment_input(user_prompt, max_length):
    words = user_prompt.split()
    segments = []
    current_segment = []

    for word in words:
        if len(" ".join(current_segment + [word])) > max_length:
            segments.append(" ".join(current_segment))
            current_segment = [word]
        else:
            current_segment.append(word)

    if current_segment:
        segments.append(" ".join(current_segment))

    return segments

def validate_tensor(tensor, name):
    if torch.is_tensor(tensor):
        if torch.isnan(tensor).any():
            raise ValueError(f"{name} contains NaN values")
        if torch.isinf(tensor).any():
            raise ValueError(f"{name} contains inf values")
        if (tensor < 0).any():
            raise ValueError(f"{name} contains values less than 0")
    else:
        raise TypeError(f"{name} is not a tensor")

def stream_text(segment, model, tokenizer, generation_config, context):
    global conversation_history
    global CONTENT

    if isinstance(model, torch.nn.DataParallel):
        model = model.module

    conversation_history.append({"role": "user", "content": segment})

    conversation_context = [{"role": "system", "content": CONTENT}] + conversation_history

    if context:
        conversation_context.append({"role": "user", "content": context})

    conversation_context = [
        {"role": "user" if i % 2 == 0 else "assistant", "content": entry["content"]}
        for i, entry in enumerate(conversation_context)
    ]

    formatted_context = [entry["content"] for entry in conversation_context]

    model_type = determine_model_type(model)

    if model_type == "gguf":
        formatted_context = tokenizer.chat_template = AutoTokenizer.from_pretrained("TheBloke/zephyr-7B-beta-GGUF").chat_template
    else:
        formatted_context = tokenizer.apply_chat_template(
            conversation_context, tokenize=False, return_tensors="pt"
        )

    input_ids = tokenizer.encode(formatted_context, return_tensors="pt")
    validate_tensor(input_ids, "input_ids")
    input_length = input_ids.size(1)
    max_model_input_size = model.config.max_position_embeddings
    if input_length > max_model_input_size:
        input_ids = input_ids[:, -max_model_input_size:]
        input_length = max_model_input_size
    input_ids = input_ids.to(model.device)

    dynamic_max_length = calculate_dynamic_max_length(model)
    max_new_tokens = determine_max_new_tokens(input_length, max_model_input_size, segment)
    max_new_tokens = min(max_new_tokens, dynamic_max_length - input_length)

    attention_mask = torch.ones_like(input_ids)

    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generated_tokens = model.generate(
        input_ids,
        attention_mask=attention_mask,
        min_length=generation_config.min_length,
        max_length=input_length + max_new_tokens,
        generation_config=generation_config,
        streamer=streamer,
    )

    validate_tensor(generated_tokens, "generated_tokens")

    output_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    conversation_history.append({"role": "system", "content": output_text})

    truncate_history_based_on_memory()

    return output_text

def call_llm():
    try:
        user_prompt = input("Enter your prompt (or type 'exit' to quit): ")
        if user_prompt.lower() == "exit":
            return "exit"
        if DEVICE == "cuda":
            with torch.inference_mode():
                max_length = calculate_dynamic_max_length(model)
                segments = segment_input(user_prompt, max_length)
                responses = []
                for segment in segments:
                    docs = db.similarity_search(segment)
                    context = " ".join([doc.page_content for doc in docs])
                    response = stream_text(segment, model, tokenizer, generation_config, context)
                    responses.append(response)
                for response in responses:
                    display(widgets.HTML(f'Chatbot: {response}'))
            flush()
        else:
            print("CPU inference not implemented yet")
    except Exception as ex:
        print(f"An error occurred during text generation: {ex}")
    finally:
        if DEVICE == "cuda":
            print_memory_usage()
        else:
            accelerator.free_memory()
    return user_prompt

def main():
    try:
        while True:
            if call_llm() == "exit":
                break
    except KeyboardInterrupt:
        print("\nExiting the program.")
    print_memory_usage()
    clear_all()

# Set memory allocation settings
set_dynamic_memory_allocation()

if __name__ == "__main__":
    main()

justheuristic commented 3 months ago

Hello! Unfortunately, we do not have the bandwidth to directly write this code, but we can give you some pointers.

If you mean to ask "how to load a pre-quantized AQLM model" (with or without PV-tuning), please refer to this guide: https://huggingface.co/docs/transformers/main/en/quantization/aqlm . You can find a list of pre-quantized models in our README or on the hub.
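For the first case, loading a pre-quantized AQLM checkpoint should look like loading any other Transformers model, provided the `aqlm` package is installed (for example `pip install aqlm[gpu,cpu]`). The following is a minimal sketch, not code from this repository: `aqlm_load_kwargs` and `load_aqlm_model` are hypothetical helper names, and the model ID is only an example of the naming used on the hub, so check it against the README list before relying on it.

```python
def aqlm_load_kwargs(device_map="auto", torch_dtype="auto"):
    """Keyword arguments forwarded to from_pretrained (assumed defaults)."""
    return {"device_map": device_map, "torch_dtype": torch_dtype}


def load_aqlm_model(model_id, **overrides):
    """Load a pre-quantized AQLM checkpoint together with its tokenizer.

    The quantization settings are read from the checkpoint's own config,
    so no BitsAndBytesConfig or GPTQ config is passed here.
    """
    # Imported lazily so the helpers can be imported without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    kwargs = {**aqlm_load_kwargs(), **overrides}
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)
    return model, tokenizer


# Example model ID (assumption: verify it against the README or the hub).
EXAMPLE_MODEL_ID = "ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"
```

In a script like the one above, `model, tokenizer = load_aqlm_model(EXAMPLE_MODEL_ID)` would take the place of the GPTQ/GGUF loading branch; the returned model can then be passed to `generate` as usual.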

If you want to run AQLM quantization ("training") for an arbitrary model from inside your code, we would not recommend that: our calibration code needs to be run in a very particular way (see the instructions in the README and scroll down for fine-tuning). If you absolutely must train the model from your script, use subprocess to invoke the calibration script as per those instructions. In principle, you can also manually merge the codebases, but this will require a lot of (careful) work, where any bug can silently ruin model quality.
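The subprocess route could be sketched roughly as below. Everything specific here is an assumption to be replaced with what the README actually says: the script name, the positional model and dataset arguments, and every flag are placeholders, and `build_calibration_command` and `run_calibration` are hypothetical helper names.

```python
import subprocess
import sys


def build_calibration_command(script, model_path, dataset, extra_args=()):
    """Return the argv list for running the calibration script in a child process."""
    # sys.executable keeps the child in the same Python environment as the caller.
    return [sys.executable, script, model_path, dataset, *extra_args]


def run_calibration(repo_dir, script, model_path, dataset, extra_args=()):
    """Run the script from inside the AQLM checkout.

    check=True raises subprocess.CalledProcessError if the script exits non-zero.
    """
    cmd = build_calibration_command(script, model_path, dataset, extra_args)
    return subprocess.run(cmd, cwd=repo_dir, check=True)
```

A call might look like `run_calibration("/path/to/AQLM", "main.py", "/path/to/model", "/path/to/dataset", ["--save", "/path/to/output"])`, with the script name and flags copied from the README rather than from this sketch. Keeping quantization in a separate process leaves the calibration code untouched, which avoids the silent quality regressions that merging the codebases could introduce.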

Unfortunately, that is all we can tell you atm.

P.S. If you need to share a large snippet of code on GitHub, you may find it more convenient to (a) surround it with triple backticks or (b) publish it as a GitHub gist and link it in your issue. This is a general recommendation: not formatting the code in this particular issue does not inconvenience anyone, since we do not have the bandwidth to help anyway. But if you ask around on other discussion boards, others may be better able to help if you provide them with a minimal, Python-formatted snippet and an explanation of what it does, how you tried to add AQLM, and what errors you encountered.

github-actions[bot] commented 2 months ago

This issue is stale because it has been open for 30 days with no activity.

github-actions[bot] commented 2 months ago

This issue was closed because it has been inactive for 14 days since being marked as stale.