RunpeiDong / DreamLLM

[ICLR 2024 Spotlight] DreamLLM: Synergistic Multimodal Comprehension and Creation
https://dreamllm.github.io/
Apache License 2.0

How can I simultaneously generate interleaved image and text content? #15

Open Y-aang opened 4 months ago

Y-aang commented 4 months ago

Hello,

This work is truly inspiring and brilliant!

Thank you very much for making it open source. I noticed that the inference script in the repository can only generate images from text. Have you previously written a script that generates interleaved image-text content? If so, it would be very kind of you to share it.

RunpeiDong commented 4 months ago

Hi @Y-aang,

Thank you for your interest! I just came back from Austria. I will try to find some time to write this script this month.

Y-aang commented 4 months ago

Dear Runpei,

Thank you for your reply. I hope you had a great time, and your efforts and contributions to DreamLLM are greatly appreciated. Wishing you all the best!

Regards, Y-aang

ZihaoLin0123 commented 3 months ago

Hi @RunpeiDong ,

Thanks so much for your excellent work. Could you please share the code for interleaved generation? Thanks!

yeongjoonJu commented 3 months ago

Thank you for your awesome work! I'm looking forward to the code for interleaved generation being shared!

YiFang99 commented 1 month ago

Hi, is the module coming?

RunpeiDong commented 1 month ago

Hi all, sorry for the delay. I still haven't had time to implement this. However, I can share an old script from an earlier version of the codebase: a Streamlit script that was used for internal development. It can no longer be run because the API has changed, but I hope you can read the code and follow its logic, which shows clearly how interleaved content is generated.

Here is the script:

# ------------------------------------------------------------------------------------------------
# Copyright (c) 2023 DreamLLM Authors. All rights reserved.
# Designed by Runpei Dong
# Last updated: 03/2023
# ------------------------------------------------------------------------------------------------

__version__ = "0.0.0"
app_name = "DreamLLM"

import os
import re
import html

from PIL import Image
import streamlit as st

import torch
import argparse
from transformers import LlamaTokenizer

from dreamllm.utils.constants import *
from dreamllm.utils.conversation import conv_templates, SeparatorStyle
from dreamllm.model.llama.modeling_dreamllm_llama_stablediffusion import DreamLLMForCausalLM as AutoModelForCausalLM

from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel

DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_DREAM_START_TOKEN = "<dream_start>"  # NOTE make llm dream!
DEFAULT_DREAM_END_TOKEN = "<dream_end>"  # NOTE make llm dream!

MAX_INPUTS = 1
chat_history = []

def demo_title(page_title="DreamLLM", page_icon=":clown_face:"):
    st.set_page_config(page_title=page_title, page_icon=page_icon, layout="wide")
    st.markdown("<h1 style='text-align: center;'> 🐑 DreamLLM: Dream, See, and Talk</h1>", unsafe_allow_html=True)
    st.markdown("<h5 style='text-align: center;'> <I>\"Do Androids Dream of Electric Sheep?\"</I> —— Philip K. Dick", unsafe_allow_html=True)
    st.markdown("<h5 style='text-align: center; color: rgba(0, 0, 0, 0.60);'> Foundation Model Department, MEGVII Technology </h5>", unsafe_allow_html=True)

def demo_spacer(n=2, line=False, next_n=0):
    for _ in range(n):
        st.write("")
    if line:
        st.tabs([" "])
    for _ in range(next_n):
        st.write("")

def build_diffusion_blocks():
    # diffusion_model_cfg = "runwayml/stable-diffusion-v1-5"
    diffusion_model_cfg = "/data/model_zoo/stable_diffusion/models--runwayml--stable-diffusion-v1-5/snapshots/aa9ba505e1973ae5cd05f5aedd345178f52f8e6a"

    noise_scheduler = DDPMScheduler.from_pretrained(diffusion_model_cfg, subfolder="scheduler")
    vae = AutoencoderKL.from_pretrained(diffusion_model_cfg, subfolder="vae", revision=None)
    diffusion_unet = UNet2DConditionModel.from_pretrained(diffusion_model_cfg, subfolder="unet", revision=None)
    text_encoder = CLIPTextModel.from_pretrained(diffusion_model_cfg, subfolder="text_encoder", revision=None)
    clip_tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_cfg, subfolder="tokenizer", revision=None)

    from transformers import CLIPImageProcessor
    from transformers import CLIPVisionModel

    image_processor_clip = CLIPImageProcessor.from_pretrained(
        "/data/hypertext/runpei/model_zoo/llm/models--openai--clip-vit-large-patch14/snapshots/8d052a0f05efbaefbc9e8786ba291cfdf93e5bff", local_files_only=True
    )
    vision_tower = CLIPVisionModel.from_pretrained(
        "/data/hypertext/runpei/model_zoo/llm/models--openai--clip-vit-large-patch14/snapshots/8d052a0f05efbaefbc9e8786ba291cfdf93e5bff", local_files_only=True
    )
    vision_tower.to(dtype=torch.float32).cuda()
    vae.requires_grad_(False)
    diffusion_unet.requires_grad_(False)
    text_encoder.requires_grad_(False)
    vae = vae.cuda()
    diffusion_unet = diffusion_unet.cuda()
    text_encoder = text_encoder.cuda()
    return noise_scheduler, vae, diffusion_unet, text_encoder, clip_tokenizer, vision_tower, image_processor_clip

@torch.inference_mode()
def generate_stream(tokenizer, model, params, device, image_idx, inp, context_len=2048, stream_interval=2):
    """Adapted from fastchat/serve/model_worker.py::generate_stream"""

    prompt = params["prompt"][1643:]
    l_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    guide_w = float(params.get("guide_w", 5.0))  # keep fractional guidance scales; int() would truncate them
    stop_str = params.get("stop", None)
    # prompt = inp
    input_ids = tokenizer(inp).input_ids
    output_ids = list(input_ids)  # copy, so appended tokens do not alias input_ids

    max_src_len = context_len - max_new_tokens - 8
    # input_ids = input_ids[-max_src_len:]
    inner_img_idx = 0
    token = -1
    for i in range(max_new_tokens):
        if i == 0:
            out = model(input_ids=torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            if token == tokenizer.convert_tokens_to_ids([DEFAULT_DREAM_START_TOKEN])[0]:  # was hardcoded as 32004
                dream_start_token = tokenizer.convert_tokens_to_ids([DEFAULT_DREAM_START_TOKEN])[0]
                dream_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_DREAM_END_TOKEN])[0]
                im_start_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN])[0]
                im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
                im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_END_TOKEN])[0]
                output_ids.append(dream_end_token)

                attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + model.model.latent_query_tokens.shape[1] + 2, device=device)

                im_embedding = model.model.embed_tokens(torch.as_tensor([[dream_start_token, dream_end_token]], device=device))
                query_embedding = model.model.latent_query_tokens
                input_embedding = torch.cat([im_embedding[..., :1, :], query_embedding, im_embedding[..., 1:, :]], 1)
                out = model(
                    input_ids=None,
                    use_cache=True,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    inputs_embeds=input_embedding,
                    past_key_values=past_key_values,
                )
                hidden_states = out.hidden_states[..., 1:-1, :]
                encoder_hidden_states = hidden_states
                encoder_hidden_states = model.condition_projector(encoder_hidden_states)

                attention_mask = torch.ones(1, model.get_model().latent_query_tokens.shape[1] + 2, device=device)
                u_out = model(
                    input_ids=None,
                    use_cache=True,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    inputs_embeds=input_embedding,
                )
                u_hidden_states = u_out.hidden_states[..., 1:-1, :]
                u_encoder_hidden_states = u_hidden_states
                u_encoder_hidden_states = model.condition_projector(u_encoder_hidden_states)

                height = model.diffusion_unet[0].config.sample_size
                width = model.diffusion_unet[0].config.sample_size
                channel = model.diffusion_unet[0].config.in_channels
                model.noise_scheduler.set_timesteps(100, device=encoder_hidden_states.device)

                latents = (
                    torch.randn((1, channel, height, width), dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
                    * model.noise_scheduler.init_noise_sigma
                )
                for t in model.noise_scheduler.timesteps:
                    with torch.no_grad():
                        latent_model_input = model.noise_scheduler.scale_model_input(latents, t)
                        latent_model_input = latent_model_input.repeat(2, 1, 1, 1)
                        noisy_pred = model.diffusion_unet[0](
                            latent_model_input, t, encoder_hidden_states=torch.cat([encoder_hidden_states, u_encoder_hidden_states], 0), return_dict=False
                        )[0]
                        noisy_pred = (1 + guide_w) * noisy_pred[:1] - guide_w * noisy_pred[1:]
                        latents = model.noise_scheduler.step(noisy_pred, t, latents)[0]

                print("generating image ...")
                image = model.vae[0].decode(latents / model.vae[0].config.scaling_factor, return_dict=False)[0]
                image = (image / 2 + 0.5).clamp(0, 1)
                image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
                image = Image.fromarray((image * 255).round().astype("uint8"))
                image.save("./tmp/image_" + str(image_idx) + "_" + str(inner_img_idx) + ".png")
                past_key_values = out.past_key_values

                image_input = model.image_processor_clip.preprocess(image, return_tensors="pt")["pixel_values"][0]
                out = model(
                    input_ids=torch.as_tensor([[im_start_token] + [im_patch_token] * 256 + [im_end_token]], device=device),
                    use_cache=True,
                    images=torch.as_tensor(torch.stack([image_input]), device=device),
                    output_hidden_states=True,
                    past_key_values=past_key_values,
                )

                inner_img_idx = inner_img_idx + 1
                logits = out.logits
                past_key_values = out.past_key_values
            else:
                attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device)
                out = model(
                    input_ids=torch.as_tensor([[token]], device=device),
                    use_cache=True,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    past_key_values=past_key_values,
                )
                logits = out.logits
                past_key_values = out.past_key_values
        if logits is not None:
            last_token_logits = logits[0][-1]
            # if token != -1 and tokenizer.decode([token])[0] == '.':
            #     print(torch.softmax(last_token_logits, dim=-1)[32004], torch.softmax(last_token_logits, dim=-1).max())
            if temperature < 1e-4:
                token = int(torch.argmax(last_token_logits))
            else:
                probs = torch.softmax(last_token_logits / temperature, dim=-1)
                token = int(torch.multinomial(probs, num_samples=1))
            output_ids.append(token)

            stopped = token == tokenizer.eos_token_id
        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            # output_ids = [ids if ids < 32000 else 0 for ids in output_ids]
            output = tokenizer.decode(output_ids, skip_special_tokens=True)

            # stop_str may be None when the caller passes no "stop" entry
            if stop_str:
                pos = output.rfind(stop_str, l_prompt)
                if pos != -1:
                    output = output[:pos]
                    stopped = True
            yield output

        if stopped:
            break

    del past_key_values

# TODO
def demo_project_signature():
    pass

def demo_meta():
    st.title("DreamLLM")
    st.markdown(
        f"""
        version {__version__}

        Let your LLM dream and see the visual world.

        example 1: Write an article about Xiangshan.

        example 2: Write a review about your experience at Craythorne's Hotel, focusing on the restaurant and bar. Describe the hotel's location and atmosphere, along with the team's friendly and professional demeanor. Discuss the menu offerings, including popular favorites, Kiwi fare, and options for various dietary preferences. Share your own dining experience, including any disappointments with the dishes' quality, temperature, and value for money. Mention your return visit and any changes in the menu or dish quality. Include any pleasant surprises, like happy hour deals or birthday celebrations, and evaluate the overall service, menu variety, and food quality.
        """
    )

def demo_license():
    st.markdown(
        f"""
        ## Terms of use
        By using this service, users are required to agree to the following terms: 
        The service is a research preview intended for non-commercial use only. 
        It only provides limited safety measures and may generate offensive content. 
        It must not be used for any illegal, harmful, violent, racist, or sexual purposes. 
        The service may collect user dialogue data for future research. 
        Please click the “Flag” button if you get any inappropriate answer! 
        We will collect those to keep improving our moderator. 
        For an optimal experience, please use desktop computers for this demo, 
        as mobile devices may compromise its quality.
        """
    )
    st.markdown(
        f"""
        ## License
        The service is a research preview intended for non-commercial use only, 
        subject to the model License of LLaMA, Terms of Use of the data generated by OpenAI, and Privacy Practices of ShareGPT. 
        Please contact us if you find any potential violation.
        """
    )

def reset_chat_history():
    """
    This function is used to reset the chat history.
    """
    st.session_state["generated"] = [["Hey there, I'm Chatty DreamLLM. I am happy to talk with you!"]]
    st.session_state["past"] = ["Hi. It's me again!"]
    # st.session_state["input"] = ""
    st.session_state["stored_session"] = []
    st.session_state["messages"] = [("Hello! I'm a chatbot that dreams and sees the visual world. How can I help you?")]

def build_model(args):
    model_name = args.model_name
    num_gpus = args.num_gpus

    # Model
    if args.device == "cuda":
        kwargs = {"torch_dtype": torch.float32}
        if num_gpus == "auto":
            kwargs["device_map"] = "auto"
        else:
            num_gpus = int(num_gpus)
            if num_gpus != 1:
                kwargs.update(
                    {
                        "device_map": "auto",
                        "max_memory": {i: "13GiB" for i in range(num_gpus)},
                    }
                )
    elif args.device == "cpu":
        kwargs = {}
    else:
        raise ValueError(f"Invalid device: {args.device}")

    tokenizer = LlamaTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, **kwargs)
    # vqgan_model = get_vqgan("vqgan_gumbel_f8_8192")
    # model.model.vision_tokenizer = [vqgan_model]
    model.get_model().config.im_start_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN])[0]
    model.get_model().config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_END_TOKEN])[0]
    model.get_model().config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
    model.get_model().config.dream_start_token = tokenizer.convert_tokens_to_ids([DEFAULT_DREAM_START_TOKEN])[0]
    model.get_model().config.dream_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_DREAM_END_TOKEN])[0]

    noise_scheduler, vae, diffusion_unet, text_encoder, clip_tokenizer, vision_tower, image_processor_clip = build_diffusion_blocks()
    model.noise_scheduler = noise_scheduler
    model.vae = [vae]
    model.diffusion_unet = [diffusion_unet]
    model.text_encoder = text_encoder
    model.clip_tokenizer = clip_tokenizer
    model.get_model().vision_tower = [vision_tower]
    model.image_processor_clip = image_processor_clip

    if args.device == "cuda" and num_gpus == 1:
        model.cuda()
    return model, tokenizer

def format_message(text):
    """
    This function is used to format the messages in the chatbot UI.

    Parameters:
    text (str): The text to be formatted.
    """
    text_blocks = re.split(r"```[\s\S]*?```", text)
    code_blocks = re.findall(r"```([\s\S]*?)```", text)

    text_blocks = [html.escape(block) for block in text_blocks]

    formatted_text = ""
    for i in range(len(text_blocks)):
        formatted_text += text_blocks[i].replace("\n", "<br>")
        if i < len(code_blocks):
            formatted_text += f'<pre style="white-space: pre-wrap; word-wrap: break-word;"><code>{html.escape(code_blocks[i])}</code></pre>'

    return formatted_text

def format_dream_messages(text):
    """
    This function is used to split interleaved outputs at each dream-token pair.

    Args:
        text (str): The generated text containing "<dream_start> <dream_end>" markers.
    """
    # NOTE this is currently hardcoded!
    return text.split(" <dream_start> <dream_end> ")

def find_dream_tokens(tokens, dream_token="<dream_start>"):
    indexes = []
    for index, text in enumerate(tokens):
        if text == dream_token:
            indexes.append(index)

    return indexes

def message_func(texts, is_user=False, images=None):
    """
    This function is used to display the messages in the chatbot UI.

    Parameters:
    texts (str | list[str]): The user's text, or the chatbot's output split into tokens.
    is_user (bool): Whether the message is from the user or the chatbot.
    images (list[str] | None): Paths of the generated images, one per "<dream_start>" token.
    """
    if is_user:
        avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=ShortHairShortWaved&accessoriesType=Round&hairColor=BrownDark&facialHairType=Blank&clotheType=BlazerSweater&eyeType=Default&eyebrowType=DefaultNatural&mouthType=Smile&skinColor=Light"
        message_alignment = "flex-end"
        message_bg_color = "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)"
        avatar_class = "user-avatar"
        text = texts  # TODO more inputs
        st.write(
            f"""
                <div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
                    <div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; max-width: 75%;">
                        {text}
                    </div>
                        <img src="{avatar_url}" class="{avatar_class}" alt="avatar" width="40"/>
                </div>
                """,
            unsafe_allow_html=True,
        )
    else:
        # avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light"
        avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=Hat&accessoriesType=Wayfarers&facialHairType=MoustacheFancy&facialHairColor=BrownDark&clotheType=Hoodie&clotheColor=Pink&eyeType=Default&eyebrowType=DefaultNatural&mouthType=Tongue&skinColor=Light"
        message_alignment = "flex-start"
        message_bg_color = "#71797E"
        avatar_class = "bot-avatar"
        dream_indexes = find_dream_tokens(texts)
        # text_merge = format_dream_messages(texts)
        dream_cnt = 0
        if len(dream_indexes) > 0:
            assert len(dream_indexes) == len(images)
            next_dream = dream_indexes[dream_cnt]
        else:
            next_dream = len(texts) - 1

        gen_text = True
        for index, text in enumerate(texts):
            if images is None:
                st.write(
                    f"""
                        <div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
                            <img src="{avatar_url}" class="{avatar_class}" alt="avatar" width="40"/>
                            <div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; max-width: 75%;">
                                {text} \n  </div>
                        </div>
                        """,
                    unsafe_allow_html=True,
                )
                break
            if text == "<dream_end>":
                continue
            elif text != "<dream_start>" and gen_text:
                # text
                _this_text_list = []
                if next_dream >= len(texts):
                    continue
                for i in range(index, next_dream):
                    _this_text_list.append(texts[i])
                _this_text = format_message(" ".join(_this_text_list))
                st.write(
                    f"""
                        <div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
                            <img src="{avatar_url}" class="{avatar_class}" alt="avatar" width="40"/>
                            <div style="background: {message_bg_color}; color: white; border-radius: 20px; padding: 10px; margin-right: 5px; max-width: 75%;">
                                {_this_text} \n  </div>
                        </div>
                        """,
                    unsafe_allow_html=True,
                )
                gen_text = False
            elif index in dream_indexes:
                # image
                if images is not None and (dream_cnt + 1) <= len(images):
                    image = Image.open(images[dream_cnt])
                    # st.write(
                    #     f"""
                    #         <div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
                    #             <img src="{avatar_url}" class="{avatar_class}" alt="avatar" width="40"/>
                    #         </div>
                    #     """,
                    #     unsafe_allow_html=True,
                    # )
                    st.image(image=image, width=224)
                    dream_cnt += 1
                    if dream_cnt < len(dream_indexes):
                        next_dream = dream_indexes[dream_cnt]
                    else:
                        next_dream = len(texts) + 1
                    gen_text = True

def update_progress_bar(value, prefix, progress_bar=None):
    if progress_bar is None:
        progress_bar = st.empty()

    key = f"{prefix}_progress_bar_value"
    if key not in st.session_state:
        st.session_state[key] = 0

    st.session_state[key] = value
    progress_bar.progress(st.session_state[key])
    if value == 100:
        st.session_state[key] = 0
        progress_bar.empty()

def chat_dreamllm(args, model, tokenizer, image_idx=0, message=None):
    # Chat
    conv = conv_templates[args.conv_template].copy()
    if message is not None:
        inp = message
    else:
        try:
            inp = input(f"{conv.roles[0]}: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")

    if args.conv_template == "dream_conv":
        conv.append_message(conv.roles[0], inp)
    else:
        conv.append_message(conv.roles[0], "nothing")
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    params = {
        "model": args.model_name,
        "prompt": prompt,
        "temperature": args.temperature,
        "max_new_tokens": args.max_new_tokens,
        "guide_w": args.guide_w,
        "stop": conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2,
    }

    print(f"{conv.roles[1]}: ", end="", flush=True)
    pre = 0
    final_images = []
    for outputs in generate_stream(tokenizer, model, params, args.device, image_idx, inp):
        outputs = outputs[len(prompt) + 1 :].strip()
        outputs = outputs.split(" ")
        now = len(outputs)
        if now - 1 > pre:
            print(" ".join(outputs[pre : now - 1]), end=" ", flush=True)
            pre = now - 1
    print(" ".join(outputs[pre:]), flush=True)

    dream_indexes = find_dream_tokens(outputs)
    for _idx in range(len(dream_indexes)):
        final_images.append("./tmp/image_" + str(image_idx) + "_" + str(_idx) + ".png")

    conv.messages[-1][-1] = (outputs, final_images)
    if args.debug:
        print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
    return (outputs, final_images)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument("--model-name", type=str, default="/mnt/host0/tmp_models/mmc4sft_model/")
    # parser.add_argument("--model-name", type=str, default="/mnt/host0/tmp_models/interleave_stage2_9kiter/")
    parser.add_argument("--model-name", type=str, default="/mnt/host0/interleave_mmc4_blippair_sft_6kiter/")

    parser.add_argument("--img-dir", type=str, default="./tmp/")
    parser.add_argument("--num-gpus", type=str, default="1")
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("--conv-template", type=str, default="dream_conv")
    # parser.add_argument("--conv-template", type=str, default="v1")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--max-new-tokens", type=int, default=2048)
    parser.add_argument("--guide_w", type=float, default=3.0)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    os.makedirs(args.img_dir, exist_ok=True)

    model, tokenizer = build_model(args)
    print("✨✨✨ Model loaded! 🤩Let's play!🤩")

    # 1. title, slogan, organization
    demo_title(page_title="DreamLLM", page_icon=":clown_face:")

    # 2. sidebar layout
    with st.sidebar:
        demo_meta()
        demo_spacer()

    # 3. define chat_window
    if "generated" not in st.session_state:
        st.session_state["generated"] = [["Hey there, I'm DreamLLM. I am happy to talk with you!"]]
    if "past" not in st.session_state:
        st.session_state["past"] = ["Hey!"]
    if "input" not in st.session_state:
        st.session_state["input"] = ""
    if "stored_session" not in st.session_state:
        st.session_state["stored_session"] = []

    if "messages" not in st.session_state:
        st.session_state["messages"] = [("Hello! I'm a chatbot that dreams and sees the visual world. How can I help you?")]

    if "query_count" not in st.session_state:
        st.session_state["query_count"] = 0

    RESET = True
    # container for chat history
    messages_container = st.container()
    with st.form(key="my_form"):
        query = st.text_input(
            "Query: ",
            key="input",
            value="",
            placeholder="Enter what you want to talk about here...",
            label_visibility="hidden",
        )
        submit_button = st.form_submit_button(label="Submit")
    col1, col2 = st.columns([1, 3.2])
    reset_button = col1.button("Reset Chat History")

    if reset_button or (st.session_state["query_count"] >= MAX_INPUTS and RESET):
        RESET = False
        st.session_state["query_count"] = 0
        reset_chat_history()

    if "messages" not in st.session_state:
        st.session_state["messages"] = []

    image_idx = 0
    if len(query) > 2 and submit_button:
        submit_progress_bar = st.empty()
        messages = st.session_state["messages"]
        update_progress_bar(33, "submit", submit_progress_bar)

        outputs = chat_dreamllm(args=args, model=model, tokenizer=tokenizer, image_idx=image_idx, message=query)

        result = {"answer": outputs[0], "images": outputs[1]}
        update_progress_bar(66, "submit", submit_progress_bar)
        st.session_state["query_count"] += 1
        messages.append((query, result["answer"]))
        st.session_state.past.append(query)
        st.session_state.generated.append((result["answer"], result["images"]))
        update_progress_bar(100, "submit", submit_progress_bar)
    image_idx += 1

    with messages_container:
        if st.session_state["generated"]:
            for i in range(len(st.session_state["generated"])):
                message_func(st.session_state["past"][i], is_user=True)
                if isinstance(st.session_state["generated"][i], tuple):
                    message_func(st.session_state["generated"][i][0], images=st.session_state["generated"][i][1])
                else:
                    message_func(st.session_state["generated"][i])

    if st.session_state["query_count"] == MAX_INPUTS and RESET:
        st.warning("You have reached the maximum number of inputs. The chat history will be cleared after the next input.")

    col2.markdown(
        f'<div style="line-height: 2.5;">{st.session_state["query_count"]}/{MAX_INPUTS}</div>',
        unsafe_allow_html=True,
    )

    st.markdown('<div id="input-container-placeholder"></div>', unsafe_allow_html=True)

    st.components.v1.html(
        """
        <script>
        window.addEventListener('load', function() {
            const inputContainer = document.querySelector('.stTextInput');
            const inputContainerPlaceholder = document.getElementById('input-container-placeholder');
            inputContainer.id = 'input-container';
            inputContainerPlaceholder.appendChild(inputContainer);
            document.getElementById("input").focus();
        });
        </script>
        """,
        height=0,
    )
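
Since the script above can no longer be run, here is a rough, self-contained sketch of the control flow in `generate_stream`: decode token by token, and whenever `<dream_start>` is sampled, pause text decoding, synthesize an image from the current context, close the span with `<dream_end>`, and continue. It is only a toy under stated assumptions, not the DreamLLM API. The names `interleaved_decode`, `scripted_model`, `fake_image`, and the `EOS` marker are hypothetical stand-ins for the LLM forward pass and the diffusion decoder, and it omits the step where the script feeds the generated image back into the model.

```python
from typing import Callable, Iterable, List, Tuple

DREAM_START = "<dream_start>"
DREAM_END = "<dream_end>"
EOS = "</s>"  # hypothetical end-of-sequence marker for this toy example


def interleaved_decode(
    next_token: Callable[[List[str]], str],
    synthesize_image: Callable[[List[str]], str],
    max_new_tokens: int = 64,
) -> List[Tuple[str, str]]:
    """Decode tokens one by one, synthesizing an image at every <dream_start>."""
    context: List[str] = []
    segments: List[Tuple[str, str]] = []  # ("text", ...) or ("image", ...)
    text_buf: List[str] = []
    for _ in range(max_new_tokens):
        token = next_token(context)
        if token == EOS:
            break
        context.append(token)
        if token == DREAM_START:
            # Flush the text collected so far, then "dream" an image.
            if text_buf:
                segments.append(("text", " ".join(text_buf)))
                text_buf = []
            segments.append(("image", synthesize_image(context)))
            # Close the dream span so later tokens see a complete pair.
            context.append(DREAM_END)
        else:
            text_buf.append(token)
    if text_buf:
        segments.append(("text", " ".join(text_buf)))
    return segments


def scripted_model(script: Iterable[str]) -> Callable[[List[str]], str]:
    """Stand-in for the LLM: replays a fixed token script, then emits EOS."""
    it = iter(script)

    def next_token(context: List[str]) -> str:
        return next(it, EOS)

    return next_token


def fake_image(context: List[str]) -> str:
    """Stand-in for the diffusion decoder: reports how much context it saw."""
    return f"image conditioned on {len(context)} tokens"


segments = interleaved_decode(
    scripted_model(["A", "sheep", DREAM_START, "on", "a", "hill", EOS]),
    fake_image,
)
for kind, payload in segments:
    print(f"{kind}: {payload}")
```

In the script above, the image branch does more work than this: it runs the learned query embeddings through the LLM, projects the hidden states into conditioning for the diffusion U-Net, and then passes the decoded image through the vision tower so that later text is conditioned on it.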