huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
132.01k stars 26.3k forks source link

Using multi GPU fails with AutoModelForCausalLM quantization_config=quantization_config #33112

Open FurkanGozukara opened 2 weeks ago

FurkanGozukara commented 2 weeks ago

I am developing an very advanced multi-GPU batch captioning APP

The below code works when I dont use quantization_config=quantization_config because i am able to set .to(device)

but when quantization_config=quantization_config is used it doesn't allow me to set .to(device)

Any ideas ?

When quantization_config=quantization_config is set the error i got is

You shouldn't move a model that is dispatched using accelerate hooks.

Error processing image R:\Joy_Caption_v1\outputs\1_3x_Ultimate_Fidelity_Standard_Texture_2_Creativity_0.png on GPU 0:.tois not supported for4-bitor8-bitbitsandbytes models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correctdtype.

If I remove .to part then the multi GPU part fails :/

The entire code is below it is not very big 460 lines I want to be able to run 8 different captionioning on 8 different GPUs

import gradio as gr
from huggingface_hub import InferenceClient
from torch import nn
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
from pathlib import Path
import torch
import torch.amp.autocast_mode
from PIL import Image, ImageOps
import numpy as np
import io
import os
import argparse
import time
import glob
import platform
from transformers import BitsAndBytesConfig
import re
import threading
from concurrent.futures import ThreadPoolExecutor
import sys

CLIP_PATH = "google/siglip-so400m-patch14-384"
VLM_PROMPT = "A descriptive caption for this image:\n"
MODEL_PATH = "rombodawg/Meta-Llama-3.1-8B-Instruct-reuploaded"
CHECKPOINT_PATH = Path("model_files")
TITLE = "<h1><center>SECourses JoyCaption Image Captioning App V11</center></h1>\n<h2><center>Official Link and Latest Version : <a href='https://www.patreon.com/posts/110613301'>https://www.patreon.com/posts/110613301</a></center></h2>\n"

HF_TOKEN = os.environ.get("HF_TOKEN", None)

class ImageAdapter(nn.Module):
    def __init__(self, input_features: int, output_features: int):
        super().__init__()
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)

    def forward(self, vision_outputs: torch.Tensor):
        x = self.linear1(vision_outputs)
        x = self.activation(x)
        x = self.linear2(x)
        return x

# Load CLIP
print("Loading CLIP")
clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
clip_model = AutoModel.from_pretrained(CLIP_PATH)
clip_model = clip_model.vision_model
clip_model.eval()
clip_model.requires_grad_(False)

# Tokenizer
print("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"

class ModelManager:
    def __init__(self, use_4bit):
        self.use_4bit = use_4bit
        self.models = {}
        self.image_adapters = {}
        self.clip_models = {}
        self.default_gpu = 0

    def get_models(self, gpu_id):
        device = f"cuda:{gpu_id}"
        if gpu_id not in self.models:
            print(f"Loading model for GPU {gpu_id}")
            text_model = self.load_model(device)
            text_model.eval()

            image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size).to(device)
            image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location=device))
            image_adapter.eval()

            clip_model_gpu = clip_model.to(device)

            self.models[gpu_id] = text_model
            self.image_adapters[gpu_id] = image_adapter
            self.clip_models[gpu_id] = clip_model_gpu

        return self.models[gpu_id], self.image_adapters[gpu_id], self.clip_models[gpu_id]

    def load_model(self, device):
        if self.use_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        else:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=False
            )

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            device_map=device,
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16 if not self.use_4bit else None
        )
        return model.to(device)

def open_folder(folder_path):
    if platform.system() == "Windows":
        os.startfile(folder_path)
    elif platform.system() == "Linux":
        os.system(f'xdg-open "{folder_path}"')
    elif platform.system() == "Darwin":  # macOS
        os.system(f'open "{folder_path}"')

def convert_to_png(image_path):
    try:
        with Image.open(image_path) as img:
            if img.mode not in ('RGB', 'RGBA'):
                img = img.convert('RGB')
            elif img.mode == 'RGBA':
                bg = Image.new('RGB', img.size, (255, 255, 255))
                bg.paste(img, mask=img.split()[3])
                img = bg

            img.thumbnail((1024, 1024))
            img_array = np.array(img)

            return img_array
    except Exception as e:
        print(f"Error converting image {image_path}: {str(e)}")
        return None

def cut_off_last_sentence(caption):
    sentence_endings = re.findall(r'[.!?]', caption)
    if sentence_endings:
        last_ending_index = caption.rfind(sentence_endings[-1])
        return caption[:last_ending_index + 1].strip()
    return caption

def remove_this_is(caption):
    caption = caption.replace("|.", "")  # Remove all "|."
    caption = " ".join(caption.split())  # Trim multiple spaces to single space
    caption = caption.strip()  # Trim leading and trailing spaces
    if caption.lower().startswith("this is "):
        caption = caption[8:]  # Remove "this is " from the beginning
    return caption

@torch.no_grad()
def generate_caption(input_image: np.ndarray, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence, device, text_model, image_adapter, clip_model, use_4bit):
    torch.cuda.empty_cache()

    image = clip_processor(images=input_image, return_tensors='pt').pixel_values.to(device)

    prompt = tokenizer.encode(VLM_PROMPT, return_tensors='pt', padding=padding, truncation=truncation, add_special_tokens=add_special_tokens).to(device)

    with torch.amp.autocast_mode.autocast(device_type='cuda', enabled=True):
        vision_outputs = clip_model(pixel_values=image, output_hidden_states=True)
        image_features = vision_outputs.hidden_states[-2]
        embedded_images = image_adapter(image_features)

    prompt_embeds = text_model.model.embed_tokens(prompt)
    assert prompt_embeds.shape == (1, prompt.shape[1], text_model.config.hidden_size), f"Prompt shape is {prompt_embeds.shape}, expected {(1, prompt.shape[1], text_model.config.hidden_size)}"
    embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=device, dtype=torch.int64))

    # Determine the target dtype based on 4-bit quantization setting
    target_dtype = torch.float16 if use_4bit else torch.bfloat16

    # Ensure all tensors are on the same device and have the same dtype before concatenation
    embedded_bos = embedded_bos.to(device).to(target_dtype)
    embedded_images = embedded_images.to(device).to(target_dtype)
    prompt_embeds = prompt_embeds.to(device).to(target_dtype)

    inputs_embeds = torch.cat([
        embedded_bos.expand(embedded_images.shape[0], -1, -1),
        embedded_images,
        prompt_embeds.expand(embedded_images.shape[0], -1, -1),
    ], dim=1)

    input_ids = torch.cat([
        torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long, device=device),
        torch.zeros((1, embedded_images.shape[1]), dtype=torch.long, device=device),
        prompt,
    ], dim=1)
    attention_mask = torch.ones_like(input_ids)

    # Ensure inputs_embeds is in the correct dtype
    inputs_embeds = inputs_embeds.to(target_dtype)

    # Use autocast for mixed precision handling
    with torch.cuda.amp.autocast(enabled=use_4bit):
        generate_ids = text_model.generate(
            input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_k=top_k,
            temperature=temperature,
            suppress_tokens=None
        )

    generate_ids = generate_ids[:, input_ids.shape[1]:]
    if generate_ids[0][-1] == tokenizer.eos_token_id:
        generate_ids = generate_ids[:, :-1]

    caption = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    caption = remove_this_is(caption.strip())

    if cut_off_sentence:
        caption = cut_off_last_sentence(caption)

    return caption

def save_caption(caption_path, caption, overwrite, append, remove_newlines):
    if remove_newlines:
        caption = ' '.join(caption.split())

    os.makedirs(os.path.dirname(caption_path), exist_ok=True)

    if os.path.exists(caption_path):
        if overwrite:
            mode = 'w'
        elif append:
            mode = 'a'
        else:
            return None
    else:
        mode = 'w'

    try:
        with open(caption_path, mode, encoding='utf-8') as f:
            if mode == 'a':
                f.write('\n')
            f.write(caption)
        return caption_path
    except Exception as e:
        print(f"Error saving caption: {e}")
        return None

def process_image(image_path, overwrite, append, remove_newlines, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence, output_folder, gpu_id, model_manager):
    os.makedirs(output_folder, exist_ok=True)

    output_image_path = os.path.join(output_folder, os.path.basename(image_path))
    save_path = os.path.splitext(output_image_path)[0] + '.txt'

    if os.path.exists(save_path) and not overwrite and not append:
        print(f"Skipped {image_path} - caption already exists.")
        return None, None, 0, output_image_path

    start_time = time.time()
    print(f"Processing {image_path} on GPU {gpu_id}...")
    try:
        image_array = convert_to_png(image_path)
        if image_array is None:
            raise ValueError("Failed to convert image")

        device = f"cuda:{gpu_id}"
        text_model, image_adapter, clip_model = model_manager.get_models(gpu_id)

        # Ensure all models are on the correct device
        text_model = text_model.to(device)
        image_adapter = image_adapter.to(device)
        clip_model = clip_model.to(device)

        caption = generate_caption(image_array, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence, device, text_model, image_adapter, clip_model, model_manager.use_4bit)
        if remove_newlines:
            caption = ' '.join(caption.split())

        actual_save_path = save_caption(save_path, caption, overwrite, append, remove_newlines)

        Image.fromarray(image_array).save(output_image_path)

        process_time = time.time() - start_time
        print(f"Processed {image_path} on GPU {gpu_id} in {process_time:.2f} seconds.")
        return caption, actual_save_path, process_time, output_image_path
    except Exception as e:
        error_message = f"Error processing image {image_path} on GPU {gpu_id}: {str(e)}\n"
        error_message += f"Image shape: {image_array.shape if image_array is not None else 'Unknown'}"
        print(error_message)
        return None, None, 0, image_path

def file_exists(file_path):
    exists = os.path.isfile(file_path)
    size = os.path.getsize(file_path) if exists else 0
    return exists, size

def prepare_gpu_plan(input_folder, output_folder, overwrite, append, gpu_ids):
    image_files = glob.glob(os.path.join(input_folder, '*.*'))
    image_files = [f for f in image_files if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'))]

    if not overwrite and not append:
        image_files = [f for f in image_files if not os.path.exists(os.path.join(output_folder, os.path.splitext(os.path.basename(f))[0] + '.txt'))]

    gpu_list = [int(gpu.strip()) for gpu in gpu_ids.split(',') if gpu.strip()]
    num_gpus = len(gpu_list)

    images_per_gpu = len(image_files) // num_gpus
    remainder = len(image_files) % num_gpus

    gpu_plans = []
    start = 0
    for i, gpu in enumerate(gpu_list):
        end = start + images_per_gpu + (1 if i < remainder else 0)
        gpu_plans.append((gpu, image_files[start:end]))
        start = end

    return gpu_plans

def process_gpu_batch(gpu_id, image_files, output_folder, overwrite, append, remove_newlines, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence, model_manager, progress_callback=None):
    total_images = len(image_files)
    processed = 0
    skipped = 0
    start_time = time.time()

    print(f"GPU {gpu_id}: Starting batch processing of {total_images} images.")

    for image_file in image_files:
        result = process_image(image_file, overwrite, append, remove_newlines, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence, output_folder, gpu_id, model_manager)

        if result[0] is None:
            skipped += 1
        else:
            processed += 1

        elapsed = time.time() - start_time
        avg_speed = elapsed / (processed + skipped) * 1000  # in milliseconds
        remaining = total_images - processed - skipped
        estimated_time = remaining * (elapsed / (processed + skipped))

        progress_message = f"GPU {gpu_id} status: {processed} processed, {skipped} skipped, {remaining} left, average speed {avg_speed:.2f} ms/image, estimated time {estimated_time // 60:.0f} min {estimated_time % 60:.0f} seconds\n"
        print(progress_message, end='')
        sys.stdout.flush()

        if progress_callback:
            progress_callback(progress_message)

    print(f"GPU {gpu_id}: Finished processing {processed} images, skipped {skipped} images.")

def multi_gpu_process(gpu_plans, output_folder, overwrite, append, remove_newlines, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence, model_manager, progress_callback=None):
    with ThreadPoolExecutor(max_workers=len(gpu_plans)) as executor:
        futures = []
        for gpu_id, image_files in gpu_plans:
            future = executor.submit(process_gpu_batch, gpu_id, image_files, output_folder, overwrite, append, remove_newlines, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence, model_manager, progress_callback)
            futures.append(future)

        for future in futures:
            future.result()

def batch_process(input_folder, output_folder, overwrite, append, remove_newlines, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence, gpu_ids, model_manager, progress_callback=None):
    if not output_folder:
        output_folder = input_folder

    gpu_plans = prepare_gpu_plan(input_folder, output_folder, overwrite, append, gpu_ids)
    total_images = sum(len(plan[1]) for plan in gpu_plans)

    progress = f"Starting batch processing. Found {total_images} images.\n"
    progress += f"Input folder: {input_folder}\n"
    progress += f"Output folder: {output_folder}\n"
    print(progress)
    if progress_callback:
        progress_callback(progress)

    multi_gpu_process(gpu_plans, output_folder, overwrite, append, remove_newlines, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence, model_manager, progress_callback)

    final_message = f"Batch processing complete. Processed {total_images} images.\n"
    print(final_message)
    if progress_callback:
        progress_callback(final_message)
    return progress + final_message

def gradio_interface(image_input, overwrite, append, remove_newlines, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence, model_manager):
    if isinstance(image_input, dict):
        image_path = image_input['path']
    else:
        image_path = image_input

    gpu_id = model_manager.default_gpu  # Use the default GPU for single image processing
    output_folder = "outputs"
    output_image_path = os.path.join(output_folder, os.path.basename(image_path))
    save_path = os.path.splitext(output_image_path)[0] + '.txt'

    # Check if the caption file already exists
    if os.path.exists(save_path) and not overwrite and not append:
        return "Skipped", f"Skipped captioning for {image_path} - caption file already exists and neither overwrite nor append was selected.\nProcessing time: 0 seconds"

    caption, actual_save_path, process_time, output_image_path = process_image(image_path, overwrite, append, remove_newlines, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence, output_folder, gpu_id, model_manager)

    if caption is None:
        save_info = f"Failed to generate caption for {image_path}."
        processing_info = f"Processing time: {process_time:.2f} seconds"
    else:
        if actual_save_path:
            save_info = f"Caption saved to: {actual_save_path}"
        else:
            save_info = f"Caption could not be saved. Attempted path: {os.path.splitext(output_image_path)[0] + '.txt'}"

        processing_info = f"Processing time: {process_time:.2f} seconds"

    return caption or "Failed", f"{save_info}\n{processing_info}"

def main():
    parser = argparse.ArgumentParser(description="SECourses JoyCaption Image Captioning App V11")
    parser.add_argument("--share", action="store_true", help="Use Gradio's share feature")
    parser.add_argument("--use_4bit", action="store_true", help="Use 4-bit quantization")
    global args
    args = parser.parse_args()

    model_manager = ModelManager(args.use_4bit)

    with gr.Blocks() as demo:
        gr.HTML(TITLE)
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(type="filepath", label="Input Image", height=512)
                overwrite = gr.Checkbox(label="Overwrite existing caption file", value=False)
                append = gr.Checkbox(label="Append new caption to existing caption", value=False)
                remove_newlines = gr.Checkbox(label="Remove newlines from generated captions", value=True)
                cut_off_sentence = gr.Checkbox(label="Cut off at last complete sentence", value=True)
                run_button = gr.Button("Caption")

            with gr.Column():
                output_caption = gr.Textbox(label="Caption")
                save_info = gr.Textbox(label="Save Information and Processing Time")
                max_new_tokens = gr.Slider(minimum=1, maximum=1000, value=300, step=1, label="Max New Tokens")
                do_sample = gr.Checkbox(label="Do Sample", value=False)
                top_k = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Top K")
                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.5, step=0.1, label="Temperature")
                padding = gr.Checkbox(label="Padding", value=False)
                truncation = gr.Checkbox(label="Truncation", value=False)
                add_special_tokens = gr.Checkbox(label="Add Special Tokens", value=False)

        with gr.Row():
            input_folder = gr.Textbox(label="Input Folder for Batch Processing")
            output_folder = gr.Textbox(label="Output Folder for Batch Processing (Optional)")
            gpu_ids = gr.Textbox(label="GPU IDs (comma-separated, e.g., 0,1,2)", value="0")
            batch_button = gr.Button("Start Batch Processing")

        batch_progress = gr.Textbox(label="Batch Processing Progress", lines=20)

        open_outputs_button = gr.Button("Open Results Folder")

        run_button.click(fn=lambda *args: gradio_interface(*args, model_manager), 
                         inputs=[input_image, overwrite, append, remove_newlines, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence], 
                         outputs=[output_caption, save_info])

        def batch_process_with_progress(*args):
            progress = ""
            def update_progress(msg):
                nonlocal progress
                progress += msg
                return progress

            return batch_process(*args, model_manager=model_manager, progress_callback=update_progress)

        batch_button.click(fn=batch_process_with_progress, 
                           inputs=[input_folder, output_folder, overwrite, append, remove_newlines, max_new_tokens, do_sample, top_k, temperature, padding, truncation, add_special_tokens, cut_off_sentence, gpu_ids],
                           outputs=batch_progress)

        open_outputs_button.click(fn=lambda: open_folder(output_folder.value or input_folder.value or "outputs"))

    demo.launch(share=args.share, inbrowser=True)

if __name__ == "__main__":
    main()

Who can help?

@SunMarc @Narsil

SunMarc commented 2 weeks ago

Hi @FurkanGozukara, thanks for the report ! I've opened this draft PR for you to try ! This should have been solved with bnb>=0.43.0 with 4-bit models but it wasn't upstreamed to transformers. Let me know how it goes ! https://github.com/huggingface/transformers/pull/33122

FurkanGozukara commented 2 weeks ago

Hi @FurkanGozukara, thanks for the report ! I've opened this draft PR for you to try ! This should have been solved with bnb>0.43.0 with 4-bit models but it wasn't upstreamed to transformers. Let me know how it goes ! #33122

awesome i will test later

i made a different implementation and made it work with 4bit :D it was really hard