ceruleandeep / ComfyUI-LLaVA-Captioner

A ComfyUI extension for chatting with your images using LLaVA. Runs locally, no external services, no filter.
GNU General Public License v3.0

Feature Submission: Seed selection. #17

Open BlastedRemnants opened 2 months ago

BlastedRemnants commented 2 months ago

Is there any way to select a seed, or get a similar effect? Sometimes I want to make the captioner try again, but to do that I have to change a setting, run it, change the setting back, then run it again. That workaround does work, which suggests there's already some kind of seed (or similar state) in play; it's just not exposed to us at the moment.
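For what it's worth, the reason the setting-toggle trick forces a retry is that ComfyUI caches node outputs and only re-executes a node when one of its input values changes, so an exposed seed widget doubles as a one-click cache invalidator. As an aside, if someone wanted "always re-run" behavior instead of a repeatable seed, I believe a node can opt out of caching with ComfyUI's optional IS_CHANGED classmethod; a minimal sketch (untested, and the NaN trick is a common custom-node idiom rather than anything specific to this repo):

    @classmethod
    def IS_CHANGED(cls, **kwargs):
        # NaN never compares equal to anything, including itself, so
        # ComfyUI always considers the node changed and re-runs it.
        return float("NaN")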

I was about to post this as a plain feature request, but decided to take a run at implementing it myself first with Copilot's help, and I have a working solution now. If anyone could test it and confirm it works for them too, that would be great; maybe it can get adopted into the default scripts for this node. Below is a paste of my edited llava.py file. Cheers!

# parts of this looted from https://github.com/pythongosssss/ComfyUI-WD14-Tagger
import asyncio
import base64
import os
import re
import time
from io import BytesIO

import numpy as np
import torch
from PIL import Image
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler

import comfy.utils
import folder_paths

model_fmt = ".gguf"
model_type = "llama"
system_message = (
    "You are an assistant who describes the content and composition of images. "
    "Describe only what you see in the image, not what you think the image is about. "
    "Be factual and literal. Do not use metaphors or similes. Be concise."
)

defaults = {
    "model": "llava-v1.5-7b-Q4_K",
    "mmproj": "llava-v1.5-7b-mmproj-Q4_0",
    "temperature": 0.2,
    "max_tokens": 40,
    "prompt": "Please describe this image in 10 to 20 words.",
    "n_gpu_layers": -1,
    "seed": 42,
}

def get_ext_dir(subpath=None, mkdir=False):
    dir = os.path.dirname(__file__)
    if subpath is not None:
        dir = os.path.join(dir, subpath)

    dir = os.path.abspath(dir)

    if mkdir and not os.path.exists(dir):
        os.makedirs(dir)
    return dir

def get_installed_models(mm_proj=False):
    if model_type not in folder_paths.folder_names_and_paths:
        models_dir = get_ext_dir("models", mkdir=True)
        folder_paths.add_model_folder_path(model_type, models_dir)

    models = folder_paths.get_filename_list(model_type)
    return [
        re.sub(rf"{model_fmt}$", "", m)
        for m in models
        if m.endswith(model_fmt) and ("mmproj" in m) == mm_proj
    ]

async def get_llava(
    model,
    mm_proj,
    n_gpu_layers=0,
):
    if n_gpu_layers is None:
        n_gpu_layers = 0

    assert isinstance(model, str), f"{model} {type(model)=}"
    assert isinstance(mm_proj, str), f"{mm_proj} {type(mm_proj)=}"
    assert isinstance(n_gpu_layers, int), f"{n_gpu_layers} {type(n_gpu_layers)=}"

    model_path = folder_paths.get_full_path(model_type, model + model_fmt)
    mmproj_path = folder_paths.get_full_path(model_type, mm_proj + model_fmt)

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model {model_path} does not exist")

    if not os.path.exists(mmproj_path):
        raise FileNotFoundError(f"Model {mmproj_path} does not exist")

    chat_handler = Llava15ChatHandler(clip_model_path=mmproj_path)

    start = time.monotonic()

    # noinspection PyTypeChecker
    llm = Llama(
        model_path=model_path,
        n_gpu_layers=n_gpu_layers,
        chat_format="llava-1-5",
        chat_handler=chat_handler,
        n_ctx=2048,  # n_ctx should be increased to accommodate the image embedding
        logits_all=True,
        verbose=False,
    )
    print(f"LLM loaded in {time.monotonic() - start:.1f}s")
    return llm

def encode(image: Image.Image):
    assert isinstance(image, Image.Image), f"{image} {type(image)}"
    with BytesIO() as output:
        image.save(output, format="PNG")
        image_bytes = output.getvalue()
    base64_image = base64.b64encode(image_bytes).decode("utf-8")
    image_url = f"data:image/png;base64,{base64_image}"
    return image_url

async def get_caption(
    llm: Llama,
    image: Image.Image,
    prompt,
    temp,
    max_tokens=35,
    seed=42,  # Add the seed parameter with a default value
):
    assert isinstance(image, Image.Image), f"{image} {type(image)=}"
    assert isinstance(system_message, str), f"{system_message} {type(system_message)=}"
    assert isinstance(prompt, str), f"{prompt} {type(prompt)=}"
    assert isinstance(temp, float), f"{temp} {type(temp)=}"
    assert isinstance(max_tokens, int), f"{max_tokens} {type(max_tokens)=}"
    assert isinstance(seed, int), f"{seed} {type(seed)=}"  # Add this assertion

    file_url = encode(image)
    messages = [
        {"role": "system", "content": system_message},
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": file_url}},
                {"type": "text", "text": prompt},
            ],
        },
    ]

    start = time.monotonic()
    response = llm.create_chat_completion(
        messages=messages,
        temperature=temp,
        max_tokens=max_tokens,
        seed=seed,  # Pass the seed to the LLM function
    )
    print(f"Response in {time.monotonic() - start:.1f}s")

    first_resp: dict = response["choices"][0]
    content = first_resp["message"]["content"]

    return content.strip()

def wait_for_async(async_fn, loop=None):
    res = []

    async def run_async():
        r = await async_fn()
        res.append(r)

    if loop is None:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No event loop in this thread yet; create a fresh one.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

    loop.run_until_complete(run_async())

    return res[0]

class LlavaCaptioner:
    @classmethod
    def INPUT_TYPES(s):
        all_models = get_installed_models()
        all_mmproj = get_installed_models(mm_proj=True)

        return {
            "required": {
                "image": ("IMAGE",),
                "model": (all_models,),
                "mm_proj": (all_mmproj,),
                "prompt": (
                    "STRING",
                    {"default": defaults["prompt"], "multiline": True},
                ),
                "max_tokens": (
                    "INT",
                    {
                        "default": defaults["max_tokens"],
                        "min": 0,
                        "max": 200,
                        "step": 5,
                    },
                ),
                "temperature": (
                    "FLOAT",
                    {
                        "default": defaults["temperature"],
                        "min": 0.0,
                        "max": 1,
                        "step": 0.1,
                    },
                ),
                "seed": (
                    "INT",
                    {
                        "default": 42,
                        "min": 0,
                        "max": 1000000,
                    },
                ),
            }
        }

    RETURN_TYPES = ("STRING",)
    OUTPUT_IS_LIST = (False,)
    FUNCTION = "caption"
    OUTPUT_NODE = True

    CATEGORY = "image"

    def caption(self, image, model, mm_proj, prompt, max_tokens, temperature, seed):  # Add seed here
        assert isinstance(image, torch.Tensor), f"{image} {type(image)=}"
        assert isinstance(model, str), f"{model} {type(model)=}"
        assert isinstance(mm_proj, str), f"{mm_proj} {type(mm_proj)=}"
        assert isinstance(prompt, str), f"{prompt} {type(prompt)=}"
        assert isinstance(max_tokens, int), f"{max_tokens} {type(max_tokens)=}"
        assert isinstance(temperature, float), f"{temperature} {type(temperature)=}"
        assert isinstance(seed, int), f"{seed} {type(seed)=}"  # Add this assertion

        tensor = image * 255
        tensor = np.array(tensor, dtype=np.uint8)

        pbar = comfy.utils.ProgressBar(tensor.shape[0] + 1)

        llava = wait_for_async(lambda: get_llava(model, mm_proj, -1))
        pbar.update(1)

        tags = []
        for i in range(tensor.shape[0]):
            image = Image.fromarray(tensor[i])
            tags.append(
                wait_for_async(
                    lambda: get_caption(
                        llava,
                        image,
                        prompt,
                        temperature,
                        max_tokens,
                        seed,  # Pass seed here
                    )
                )
            )
            pbar.update(1)
        result = "\n".join(tags)
        return {"ui": {"tags": tags}, "result": (result,)}

NODE_CLASS_MAPPINGS = {
    "LlavaCaptioner": LlavaCaptioner,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "LlavaCaptioner": "LLaVA Captioner 🌊",
}
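If anyone wants to sanity-check the seeding outside of a full workflow, here is a rough harness I would expect to work (an untested sketch: the model names and image path are placeholders, it has to run in a Python environment where ComfyUI's folder_paths is importable, and it assumes a llama-cpp-python version whose create_chat_completion accepts the seed argument):

import asyncio
from PIL import Image
from llava import get_llava, get_caption  # the edited file above

async def main():
    llm = await get_llava("llava-v1.5-7b-Q4_K", "llava-v1.5-7b-mmproj-Q4_0", -1)
    img = Image.open("test.png").convert("RGB")  # placeholder image path
    prompt = "Please describe this image in 10 to 20 words."
    a = await get_caption(llm, img, prompt, 0.2, 40, seed=42)
    b = await get_caption(llm, img, prompt, 0.2, 40, seed=42)
    c = await get_caption(llm, img, prompt, 0.2, 40, seed=7)
    print("same seed reproduces:", a == b)  # expect True if seeding works
    print("different seed:", c)

asyncio.run(main())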
BlastedRemnants commented 2 months ago

Sorry, by the way: I thought this would ask me to add labels and whatnot at some point, but I must have missed that somehow.

Update: I submitted a pull request with the changes to this file. I don't really know what I'm doing with Git, though, so apologies if I did it wrong. Also, feel free to close or delete this issue if that's the proper thing to do now that it's basically resolved. Thanks!