Open · Y-aang opened this issue 4 months ago
Hi @Y-aang,
Thank you for your interest! I just came back from Austria. I will try to find some time to write this script this month.
Dear Runpei,
Thank you for your reply. I hope you had a great time, and your efforts and contributions to DreamLLM are greatly appreciated. Wishing you all the best!
Regards, Y-aang
On 14 May 2024, at 17:23, Runpei Dong <@.***> wrote:
Hi @Y-aang https://github.com/Y-aang,
Thank you for your interest! I just came back from Austria. I will try to find some time to write this script this month.
— Reply to this email directly, view it on GitHub https://github.com/RunpeiDong/DreamLLM/issues/15#issuecomment-2109706164, or unsubscribe https://github.com/notifications/unsubscribe-auth/AWPQEVLZ4TRRENEL6Y43ANLZCHJXLAVCNFSM6AAAAABHLRYLL6VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDCMBZG4YDMMJWGQ. You are receiving this because you were mentioned.
Hi @RunpeiDong ,
Thanks so much for your excellent work. Would you please share the code for interleaved generation? Thanks!
Thank you for your awesome work! I'm looking forward to the release of the code for interleaved generation!
Hi, is the module coming?
Hi all, sorry for the delay — I still haven't had the time to implement this. However, I can show you an old script from an earlier version of the codebase: a Streamlit app that was used for internal development. This script cannot be run now because the API has changed, but I hope you can read the code and follow its logic, which shows quite clearly how interleaved content is generated.
Here is the script:
# ------------------------------------------------------------------------------------------------
# Copyright (c) 2023 DreamLLM Authors. All rights reserved.
# Designed by Runpei Dong
# Last updated: 03/2023
# ------------------------------------------------------------------------------------------------
# Module metadata; __version__ is shown in the sidebar by demo_meta.
__version__, app_name = "0.0.0", "DreamLLM"
import os
import re
import html
from PIL import Image
import streamlit as st
import torch
import argparse
from transformers import LlamaTokenizer
from dreamllm.utils.constants import *
from dreamllm.utils.conversation import conv_templates, SeparatorStyle
from dreamllm.model.llama.modeling_dreamllm_llama_stablediffusion import DreamLLMForCausalLM as AutoModelForCausalLM
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
# Special-token strings. They are assigned after the wildcard import of
# dreamllm.utils.constants, so they take precedence over any same-named values.
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN = "<im_start>", "<im_end>"
# Markers the LLM emits around a span it wants to "dream" (render as an image).
DEFAULT_DREAM_START_TOKEN, DEFAULT_DREAM_END_TOKEN = "<dream_start>", "<dream_end>"

# Number of submitted queries after which the chat history is reset.
MAX_INPUTS = 1
# Not referenced by any code in this script.
chat_history = []
def demo_title(page_title="DreamLLM", page_icon=":clown_face:"):
    """Configure the Streamlit page and render the centred title banner.

    Args:
        page_title: Browser-tab title passed to ``st.set_page_config``.
        page_icon: Browser-tab icon passed to ``st.set_page_config``.
    """
    # Streamlit requires set_page_config to be the first Streamlit call on a page.
    st.set_page_config(page_title=page_title, page_icon=page_icon, layout="wide")
    st.markdown("<h1 style='text-align: center;'> 🐑 DreamLLM: Dream, See, and Talk</h1>", unsafe_allow_html=True)
    # The closing </h5> was missing, leaving the heading element unterminated.
    st.markdown(
        "<h5 style='text-align: center;'> <I>\"Do Androids Dream of Electric Sheep?\"</I> —— Philip K. Dick</h5>",
        unsafe_allow_html=True,
    )
    st.markdown(
        "<h5 style='text-align: center; color: rgba(0, 0, 0, 0.60);'> Foundation Model Department, MEGVII Technology </h5>",
        unsafe_allow_html=True,
    )
def demo_spacer(n=2, line=False, next_n=0):
    """Insert vertical whitespace, optionally with a divider in the middle.

    Args:
        n: Number of empty ``st.write`` lines emitted first.
        line: When true, emit a single blank-labelled tab strip as a visual divider.
        next_n: Number of empty ``st.write`` lines emitted after the divider.
    """

    def _blank_lines(count):
        for _ in range(count):
            st.write("")

    _blank_lines(n)
    if line:
        st.tabs([" "])
    _blank_lines(next_n)
def build_diffusion_blocks(
    diffusion_model_cfg="/data/model_zoo/stable_diffusion/models--runwayml--stable-diffusion-v1-5/snapshots/aa9ba505e1973ae5cd05f5aedd345178f52f8e6a",
    clip_model_cfg="/data/hypertext/runpei/model_zoo/llm/models--openai--clip-vit-large-patch14/snapshots/8d052a0f05efbaefbc9e8786ba291cfdf93e5bff",
):
    """Load the Stable Diffusion and CLIP components that DreamLLM dreams with.

    The checkpoint locations used to be hard-coded; they are now parameters whose
    defaults are the original paths, so existing callers are unaffected.

    Args:
        diffusion_model_cfg: Path or hub id of a Stable Diffusion checkpoint with
            ``scheduler``/``vae``/``unet``/``text_encoder``/``tokenizer`` subfolders
            (the original hub id was ``"runwayml/stable-diffusion-v1-5"``).
        clip_model_cfg: Local path of the CLIP ViT-L/14 checkpoint that provides the
            image processor and vision tower (loaded with ``local_files_only=True``).

    Returns:
        ``(noise_scheduler, vae, diffusion_unet, text_encoder, clip_tokenizer,
        vision_tower, image_processor_clip)``. ``vae``, ``diffusion_unet``,
        ``text_encoder`` and ``vision_tower`` are moved to CUDA; the first three are
        frozen with ``requires_grad_(False)``.
    """
    from transformers import CLIPImageProcessor, CLIPVisionModel

    noise_scheduler = DDPMScheduler.from_pretrained(diffusion_model_cfg, subfolder="scheduler")
    vae = AutoencoderKL.from_pretrained(diffusion_model_cfg, subfolder="vae", revision=None)
    diffusion_unet = UNet2DConditionModel.from_pretrained(diffusion_model_cfg, subfolder="unet", revision=None)
    text_encoder = CLIPTextModel.from_pretrained(diffusion_model_cfg, subfolder="text_encoder", revision=None)
    clip_tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_cfg, subfolder="tokenizer", revision=None)

    image_processor_clip = CLIPImageProcessor.from_pretrained(clip_model_cfg, local_files_only=True)
    vision_tower = CLIPVisionModel.from_pretrained(clip_model_cfg, local_files_only=True)
    vision_tower.to(dtype=torch.float32).cuda()

    for module in (vae, diffusion_unet, text_encoder):
        module.requires_grad_(False)
    vae = vae.cuda()
    diffusion_unet = diffusion_unet.cuda()
    text_encoder = text_encoder.cuda()
    return noise_scheduler, vae, diffusion_unet, text_encoder, clip_tokenizer, vision_tower, image_processor_clip
def _sample_dream_image(model, encoder_hidden_states, u_encoder_hidden_states, guide_w, num_inference_steps=100):
    """Denoise one latent with classifier-free guidance and decode it to a PIL image.

    Args:
        model: DreamLLM model carrying ``noise_scheduler``, ``diffusion_unet`` (a
            one-element list) and ``vae`` (a one-element list), as attached in build_model.
        encoder_hidden_states: Conditional UNet context projected from the LLM.
        u_encoder_hidden_states: Unconditional UNet context used for guidance.
        guide_w: Guidance weight ``w``; prediction is ``(1 + w) * cond - w * uncond``.
        num_inference_steps: Number of scheduler timesteps.

    Returns:
        A uint8 ``PIL.Image.Image``.
    """
    unet = model.diffusion_unet[0]
    vae = model.vae[0]
    scheduler = model.noise_scheduler
    height = unet.config.sample_size
    width = unet.config.sample_size
    channel = unet.config.in_channels
    scheduler.set_timesteps(num_inference_steps, device=encoder_hidden_states.device)
    latents = (
        torch.randn((1, channel, height, width), dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device)
        * scheduler.init_noise_sigma
    )
    # Conditional and unconditional contexts are batched into a single UNet call;
    # the concatenation is loop-invariant, so it is built once.
    context = torch.cat([encoder_hidden_states, u_encoder_hidden_states], 0)
    with torch.no_grad():
        for t in scheduler.timesteps:
            latent_model_input = scheduler.scale_model_input(latents, t)
            latent_model_input = latent_model_input.repeat(2, 1, 1, 1)
            noisy_pred = unet(latent_model_input, t, encoder_hidden_states=context, return_dict=False)[0]
            noisy_pred = (1 + guide_w) * noisy_pred[:1] - guide_w * noisy_pred[1:]
            latents = scheduler.step(noisy_pred, t, latents)[0]
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
    return Image.fromarray((image * 255).round().astype("uint8"))


@torch.inference_mode()
def generate_stream(tokenizer, model, params, device, image_idx, inp, context_len=2048, stream_interval=2):
    """Stream interleaved text/image generation from DreamLLM.

    Adapted from fastchat/serve/model_worker.py::generate_stream. Tokens are sampled
    autoregressively. Whenever the sampled token is ``<dream_start>``, the model's
    latent query tokens are run through the LLM (with and without the cached
    context), projected by ``model.condition_projector`` and decoded into an image.
    The image is saved to ``./tmp/image_{image_idx}_{k}.png`` and fed back to the LLM
    as ``<im_start> <im_patch>*256 <im_end>`` before text generation resumes.

    Args:
        tokenizer: Tokenizer with the DreamLLM special tokens added.
        model: DreamLLM model with the diffusion components attached (see build_model).
        params: Dict with ``"prompt"``, ``"temperature"``, ``"max_new_tokens"``,
            ``"guide_w"`` and ``"stop"``.
        device: Device for the input tensors.
        image_idx: Index embedded in the saved image filenames (one per chat turn).
        inp: Raw user message; this, not ``params["prompt"]``, is what gets tokenized.
        context_len: Currently unused (input truncation is disabled); kept for API
            compatibility.
        stream_interval: Yield the decoded text every this many steps.

    Yields:
        The decoded text so far (the input message included), cut at the stop string
        when it occurs after position ``l_prompt``.
    """
    # NOTE(review): the fixed 1643-character offset presumably strips the
    # conversation template's preamble — confirm against conv_templates.
    prompt = params["prompt"][1643:]
    l_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    # The guidance weight is fractional (CLI default 3.0); int() used to truncate it.
    guide_w = float(params.get("guide_w", 5.0))
    stop_str = params.get("stop", None)

    def _token_id(token_str):
        return tokenizer.convert_tokens_to_ids([token_str])[0]

    # Looked up once instead of comparing against the hard-coded id 32004.
    dream_start_token = _token_id(DEFAULT_DREAM_START_TOKEN)
    dream_end_token = _token_id(DEFAULT_DREAM_END_TOKEN)
    im_start_token = _token_id(DEFAULT_IM_START_TOKEN)
    im_patch_token = _token_id(DEFAULT_IMAGE_PATCH_TOKEN)
    im_end_token = _token_id(DEFAULT_IM_END_TOKEN)

    input_ids = tokenizer(inp).input_ids
    output_ids = list(input_ids)
    inner_img_idx = 0
    token = -1
    past_key_values = None
    for i in range(max_new_tokens):
        if i == 0:
            out = model(input_ids=torch.as_tensor([input_ids], device=device), use_cache=True)
            logits = out.logits
            past_key_values = out.past_key_values
        elif token == dream_start_token:
            # Close the dream span in the visible output, then run
            # <dream_start> + latent queries + <dream_end> through the LLM.
            output_ids.append(dream_end_token)
            query_embedding = model.model.latent_query_tokens
            im_embedding = model.model.embed_tokens(torch.as_tensor([[dream_start_token, dream_end_token]], device=device))
            input_embedding = torch.cat([im_embedding[..., :1, :], query_embedding, im_embedding[..., 1:, :]], 1)

            # Conditional pass: the queries attend to the whole cached context.
            attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + query_embedding.shape[1] + 2, device=device)
            out = model(
                input_ids=None,
                use_cache=True,
                attention_mask=attention_mask,
                output_hidden_states=True,
                inputs_embeds=input_embedding,
                past_key_values=past_key_values,
            )
            # NOTE(review): this assumes out.hidden_states is a (batch, seq, dim) tensor
            # (stock Hugging Face models return a tuple) — confirm for DreamLLMForCausalLM.
            # [1:-1] drops the <dream_start>/<dream_end> positions.
            encoder_hidden_states = model.condition_projector(out.hidden_states[..., 1:-1, :])

            # Unconditional pass for guidance: same embeddings, no cached context.
            attention_mask = torch.ones(1, model.get_model().latent_query_tokens.shape[1] + 2, device=device)
            u_out = model(
                input_ids=None,
                use_cache=True,
                attention_mask=attention_mask,
                output_hidden_states=True,
                inputs_embeds=input_embedding,
            )
            u_encoder_hidden_states = model.condition_projector(u_out.hidden_states[..., 1:-1, :])

            print("generating image ...")
            image = _sample_dream_image(model, encoder_hidden_states, u_encoder_hidden_states, guide_w)
            image.save("./tmp/image_" + str(image_idx) + "_" + str(inner_img_idx) + ".png")

            # Feed the generated image back as visual tokens, continuing from the
            # conditional pass's cache.
            past_key_values = out.past_key_values
            image_input = model.image_processor_clip.preprocess(image, return_tensors="pt")["pixel_values"][0]
            out = model(
                input_ids=torch.as_tensor([[im_start_token] + [im_patch_token] * 256 + [im_end_token]], device=device),
                use_cache=True,
                images=torch.as_tensor(torch.stack([image_input]), device=device),
                output_hidden_states=True,
                past_key_values=past_key_values,
            )
            inner_img_idx += 1
            logits = out.logits
            past_key_values = out.past_key_values
        else:
            attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device)
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                attention_mask=attention_mask,
                output_hidden_states=True,
                past_key_values=past_key_values,
            )
            logits = out.logits
            past_key_values = out.past_key_values

        last_token_logits = logits[0][-1]
        if temperature < 1e-4:
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits / temperature, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))
        output_ids.append(token)
        stopped = token == tokenizer.eos_token_id

        if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
            output = tokenizer.decode(output_ids, skip_special_tokens=True)
            # Without this guard, a missing stop string (None) made rfind raise TypeError.
            if stop_str:
                pos = output.rfind(stop_str, l_prompt)
                if pos != -1:
                    output = output[:pos]
                    stopped = True
            yield output

        if stopped:
            break
# TODO: render a project signature section.
def demo_project_signature():
    """Placeholder for the project-signature section; renders nothing yet."""
def demo_meta():
    """Render the sidebar header: app title, version, tagline and two example prompts."""
    sidebar_text = f"""
version {__version__}
Let your LLM dream and see the visual world.
example 1: Write an article about Xiangshan.
example 2: Write a review about your experience at Craythorne's Hotel, focusing on the restaurant and bar. Describe the hotel's location and atmosphere, along with the team's friendly and professional demeanor. Discuss the menu offerings, including popular favorites, Kiwi fare, and options for various dietary preferences. Share your own dining experience, including any disappointments with the dishes' quality, temperature, and value for money. Mention your return visit and any changes in the menu or dish quality. Include any pleasant surprises, like happy hour deals or birthday celebrations, and evaluate the overall service, menu variety, and food quality.
"""
    st.title("DreamLLM")
    st.markdown(sidebar_text)
def demo_liscence():
    """Render the "Terms of use" and "License" sections as Markdown.

    The sections contain no placeholders, so they are plain (non-f) string literals.
    """
    terms_of_use = """
## Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only.
It only provides limited safety measures and may generate offensive content.
It must not be used for any illegal, harmful, violent, racist, or sexual purposes.
The service may collect user dialogue data for future research.
Please click the “Flag” button if you get any inappropriate answer!
We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo,
as mobile devices may compromise its quality.
"""
    license_text = """
## License
The service is a research preview intended for non-commercial use only,
subject to the model License of LLaMA, Terms of Use of the data generated by OpenAI, and Privacy Practices of ShareGPT.
Please contact us if you find any potential violation.
"""
    for section in (terms_of_use, license_text):
        st.markdown(section)
def reset_chat_history():
    """Restore the chat-related session state to its post-reset values.

    Overwrites ``generated``, ``past``, ``stored_session`` and ``messages`` in
    ``st.session_state``; the text-input value is left untouched.
    """
    state = st.session_state
    state["generated"] = [["Hey there, I'm Chatty DreamLLM. I am happy to talk with you!"]]
    state["past"] = ["Hi. It's me again!"]
    state["stored_session"] = []
    state["messages"] = ["Hello! I'm a chatbot that dreams and see the visual world. How can I help you?"]
def build_model(args):
    """Load the DreamLLM tokenizer and model and attach the diffusion components.

    Args:
        args: Parsed CLI namespace; reads ``model_name``, ``num_gpus`` (``"auto"`` or
            an integer string) and ``device`` (``"cuda"`` or ``"cpu"``).

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: If ``args.device`` is neither ``"cuda"`` nor ``"cpu"``.
    """
    model_name, num_gpus, device = args.model_name, args.num_gpus, args.device

    if device == "cpu":
        kwargs = {}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float32}
        if num_gpus == "auto":
            kwargs["device_map"] = "auto"
        else:
            num_gpus = int(num_gpus)
            if num_gpus != 1:
                # Shard across several GPUs with a per-device memory cap.
                kwargs["device_map"] = "auto"
                kwargs["max_memory"] = {gpu: "13GiB" for gpu in range(num_gpus)}
    else:
        raise ValueError(f"Invalid device: {args.device}")

    tokenizer = LlamaTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, **kwargs)

    # Record the special-token ids on the inner model's config.
    model_config = model.get_model().config
    special_tokens = (
        ("im_start_token", DEFAULT_IM_START_TOKEN),
        ("im_end_token", DEFAULT_IM_END_TOKEN),
        ("im_patch_token", DEFAULT_IMAGE_PATCH_TOKEN),
        ("dream_start_token", DEFAULT_DREAM_START_TOKEN),
        ("dream_end_token", DEFAULT_DREAM_END_TOKEN),
    )
    for attr_name, token_str in special_tokens:
        setattr(model_config, attr_name, tokenizer.convert_tokens_to_ids([token_str])[0])

    (
        noise_scheduler,
        vae,
        diffusion_unet,
        text_encoder,
        clip_tokenizer,
        vision_tower,
        image_processor_clip,
    ) = build_diffusion_blocks()
    model.noise_scheduler = noise_scheduler
    # NOTE(review): vae/unet/vision_tower are wrapped in one-element lists,
    # presumably so nn.Module does not register them as submodules — confirm.
    model.vae = [vae]
    model.diffusion_unet = [diffusion_unet]
    model.text_encoder = text_encoder
    model.clip_tokenizer = clip_tokenizer
    model.get_model().vision_tower = [vision_tower]
    model.image_processor_clip = image_processor_clip

    if device == "cuda" and num_gpus == 1:
        model.cuda()
    return model, tokenizer
def format_message(text):
    """Convert a chat message into an HTML fragment for the chat bubbles.

    Text outside triple-backtick fences is HTML-escaped and its newlines become
    ``<br>``. Each fenced span becomes an HTML-escaped, wrapping ``<pre><code>``
    block. Unmatched fences are treated as ordinary text.

    Parameters:
        text (str): The text to be formatted.

    Returns:
        str: The formatted HTML.
    """
    # With a capturing group, re.split alternates prose and fenced code:
    # [prose, code, prose, code, ..., prose].
    pieces = re.split(r"```([\s\S]*?)```", text)
    html_parts = []
    for position, piece in enumerate(pieces):
        if position % 2:
            html_parts.append(f'<pre style="white-space: pre-wrap; word-wrap: break-word;"><code>{html.escape(piece)}</code></pre>')
        else:
            html_parts.append(html.escape(piece).replace("\n", "<br>"))
    return "".join(html_parts)
def format_dream_messages(text):
    """Split interleaved output into the text pieces around each image slot.

    Args:
        text: Decoded model output.

    Returns:
        The substrings of ``text`` separated by the literal
        ``" <dream_start> <dream_end> "``.
    """
    # NOTE: the separator (including its surrounding spaces) is hard-coded, so
    # dream markers with any other spacing are not split on.
    separator = " <dream_start> <dream_end> "
    return text.split(separator)
def find_dream_tokens(tokens, dream_token="<dream_start>"):
    """Return the positions at which ``dream_token`` occurs in ``tokens``.

    Args:
        tokens: Sequence of output words (e.g. a reply split on spaces).
        dream_token: Marker to search for; defaults to the dream-start token.

    Returns:
        List of indexes, in increasing order, whose element equals ``dream_token``.
    """
    return [index for index, token in enumerate(tokens) if token == dream_token]
_USER_AVATAR_URL = "https://avataaars.io/?avatarStyle=Transparent&topType=ShortHairShortWaved&accessoriesType=Round&hairColor=BrownDark&facialHairType=Blank&clotheType=BlazerSweater&eyeType=Default&eyebrowType=DefaultNatural&mouthType=Smile&skinColor=Light"
_BOT_AVATAR_URL = "https://avataaars.io/?avatarStyle=Transparent&topType=Hat&accessoriesType=Wayfarers&facialHairType=MoustacheFancy&facialHairColor=BrownDark&clotheType=Hoodie&clotheColor=Pink&eyeType=Default&eyebrowType=DefaultNatural&mouthType=Tongue&skinColor=Light"


def _render_chat_bubble(content, avatar_url, avatar_class, alignment, bg_color, avatar_first):
    """Write one chat bubble (avatar plus coloured message box) as raw HTML.

    ``content`` is inserted verbatim, so callers must pass already-escaped HTML.
    """
    avatar = f'<img src="{avatar_url}" class="{avatar_class}" alt="avatar" width="40"/>'
    bubble = (
        f'<div style="background: {bg_color}; color: white; border-radius: 20px; '
        f'padding: 10px; margin-right: 5px; max-width: 75%;">{content}</div>'
    )
    inner = avatar + bubble if avatar_first else bubble + avatar
    # Emitted as one line: indented lines inside st.write's Markdown may be
    # rendered as a code block instead of HTML.
    st.write(
        f'<div style="display: flex; align-items: center; margin-bottom: 10px; '
        f'justify-content: {alignment};">{inner}</div>',
        unsafe_allow_html=True,
    )


def message_func(texts, is_user=False, images=None):
    """Display one chat message in the chatbot UI.

    All text is passed through format_message, so it is HTML-escaped before being
    written with ``unsafe_allow_html=True``.

    Parameters:
        texts: For the user, the query string. For the bot, a list of words
            (the reply split on spaces) that may contain ``<dream_start>`` /
            ``<dream_end>`` markers.
        is_user (bool): Whether the message is from the user or the chatbot.
        images: Image file paths, consumed in order, one per ``<dream_start>``
            marker. ``None`` means a plain reply, for which only ``texts[0]`` is shown.
    """
    if is_user:
        _render_chat_bubble(
            format_message(texts),
            _USER_AVATAR_URL,
            "user-avatar",
            "flex-end",
            "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)",
            avatar_first=False,
        )
        return

    def _bot_bubble(content):
        _render_chat_bubble(content, _BOT_AVATAR_URL, "bot-avatar", "flex-start", "#71797E", avatar_first=True)

    if images is None:
        # Plain reply (e.g. the greeting): only the first entry is shown.
        if texts:
            _bot_bubble(format_message(texts[0]))
        return

    # Interleaved reply: text between markers becomes one bubble, and each
    # <dream_start> shows the next image. Text after the last image is rendered
    # too (previously it was skipped, and image-less replies lost their last word).
    segment = []
    image_cnt = 0

    def _flush_segment():
        if segment:
            _bot_bubble(format_message(" ".join(segment)))
            segment.clear()

    for word in texts:
        if word == DEFAULT_DREAM_END_TOKEN:
            continue
        if word == DEFAULT_DREAM_START_TOKEN:
            _flush_segment()
            # A missing image (fewer paths than markers) is skipped rather than asserted.
            if image_cnt < len(images):
                with Image.open(images[image_cnt]) as image:
                    st.image(image=image, width=224)
            image_cnt += 1
        else:
            segment.append(word)
    _flush_segment()
def update_progress_bar(value, prefix, progress_bar=None):
    """Show ``value`` on a progress bar whose value is tracked in session state.

    Args:
        value: Value handed to ``progress``; at exactly 100 the stored value is reset
            to 0 and the bar is cleared.
        prefix: Namespace for the session-state key ``f"{prefix}_progress_bar_value"``.
        progress_bar: Placeholder to draw into; a fresh ``st.empty()`` when omitted.
    """
    bar = st.empty() if progress_bar is None else progress_bar
    state_key = f"{prefix}_progress_bar_value"
    if state_key not in st.session_state:
        st.session_state[state_key] = 0
    st.session_state[state_key] = value
    bar.progress(st.session_state[state_key])
    if value != 100:
        return
    st.session_state[state_key] = 0
    bar.empty()
def chat_dreamllm(args, model, tokenizer, image_idx=0, message=None):
    """Run one DreamLLM chat turn and collect the interleaved reply.

    Args:
        args: CLI namespace; reads ``conv_template``, ``model_name``, ``temperature``,
            ``max_new_tokens``, ``guide_w``, ``device`` and ``debug``.
        model: DreamLLM model from build_model.
        tokenizer: Matching tokenizer from build_model.
        image_idx: Turn index embedded in the saved image filenames.
        message: User message; when ``None`` it is read from stdin.

    Returns:
        ``(words, image_paths)``: the reply split on single spaces, and one
        ``./tmp/image_{image_idx}_{k}.png`` path per ``<dream_start>`` in it.
        An empty message returns ``([], [])`` without running the model.
    """
    conv = conv_templates[args.conv_template].copy()
    if message is not None:
        inp = message
    else:
        try:
            inp = input(f"{conv.roles[0]}: ")
        except EOFError:
            inp = ""
    if not inp:
        # Previously "exit..." was printed but generation still ran on an empty prompt.
        print("exit...")
        return [], []

    # NOTE(review): non-"dream_conv" templates get the placeholder user turn
    # "nothing" — confirm this is intended.
    user_turn = inp if args.conv_template == "dream_conv" else "nothing"
    conv.append_message(conv.roles[0], user_turn)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()
    params = {
        "model": args.model_name,
        "prompt": prompt,
        "temperature": args.temperature,
        "max_new_tokens": args.max_new_tokens,
        "guide_w": args.guide_w,
        "stop": conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2,
    }
    print(f"{conv.roles[1]}: ", end="", flush=True)

    pre = 0
    # Initialised so an empty stream (e.g. max_new_tokens == 0) cannot leave it unbound.
    outputs = []
    for partial in generate_stream(tokenizer, model, params, args.device, image_idx, inp):
        # NOTE(review): generate_stream decodes the tokenized ``inp`` plus the reply,
        # whereas this slice assumes the text starts with ``prompt`` — confirm for
        # the template in use.
        outputs = partial[len(prompt) + 1 :].strip().split(" ")
        now = len(outputs)
        # Print all but the last (possibly incomplete) word of each update.
        if now - 1 > pre:
            print(" ".join(outputs[pre : now - 1]), end=" ", flush=True)
            pre = now - 1
    print(" ".join(outputs[pre:]), flush=True)

    final_images = [
        "./tmp/image_" + str(image_idx) + "_" + str(k) + ".png" for k in range(len(find_dream_tokens(outputs)))
    ]
    conv.messages[-1][-1] = (outputs, final_images)
    if args.debug:
        print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
    return (outputs, final_images)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="/mnt/host0/interleave_mmc4_blippair_sft_6kiter/")
    parser.add_argument("--img-dir", type=str, default="./tmp/")
    parser.add_argument("--num-gpus", type=str, default="1")
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("--conv-template", type=str, default="dream_conv")
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--max-new-tokens", type=int, default=2048)
    parser.add_argument("--guide_w", type=float, default=3.0)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    # generate_stream and chat_dreamllm always write/read images under ./tmp/, so
    # create it alongside --img-dir; makedirs also creates missing parent directories.
    for _dir in (args.img_dir, "./tmp/"):
        os.makedirs(_dir, exist_ok=True)

    # Streamlit re-executes this script on every interaction. Cache the model so it
    # is loaded once per server process instead of on every rerun.
    def _load_model(_args):
        return build_model(_args)

    _cache_resource = getattr(st, "cache_resource", None)  # available in newer Streamlit releases
    if _cache_resource is not None:
        # The leading underscore keeps Streamlit from hashing the argparse namespace;
        # show_spinner=False avoids emitting an element before set_page_config.
        _load_model = _cache_resource(show_spinner=False)(_load_model)
    model, tokenizer = _load_model(args)
    print("✨✨✨ Model loaded! 🤩Let's play!🤩")

    # 1. title, slogan, organization
    demo_title(page_title="DreamLLM", page_icon=":clown_face:")

    # 2. sidebar layout
    with st.sidebar:
        demo_meta()
        demo_spacer()

    # 3. chat window state
    session_defaults = {
        "generated": [["Hey there, I'm DreamLLM. I am happy to talk with you!"]],
        "past": ["Hey!"],
        "input": "",
        "stored_session": [],
        "messages": ["Hello! I'm a chatbot that dreams and see the visual world. How can I help you?"],
        "query_count": 0,
        # Per-session counter for image filenames. A plain local reset to 0 on every
        # rerun, so each turn overwrote ./tmp/image_0_*.png and older messages in the
        # history displayed the newest images.
        "image_idx": 0,
    }
    for key, value in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    RESET = True
    # container for chat history
    messages_container = st.container()
    with st.form(key="my_form"):
        query = st.text_input(
            "Query: ",
            key="input",
            value="",
            placeholder="Enter what you want to talk here...",
            label_visibility="hidden",
        )
        submit_button = st.form_submit_button(label="Submit")
    col1, col2 = st.columns([1, 3.2])
    reset_button = col1.button("Reset Chat History")

    if reset_button or (st.session_state["query_count"] >= MAX_INPUTS and RESET):
        RESET = False
        st.session_state["query_count"] = 0
        reset_chat_history()

    if len(query) > 2 and submit_button:
        image_idx = st.session_state["image_idx"]
        st.session_state["image_idx"] = image_idx + 1
        submit_progress_bar = st.empty()
        messages = st.session_state["messages"]
        update_progress_bar(33, "submit", submit_progress_bar)
        answer, answer_images = chat_dreamllm(args=args, model=model, tokenizer=tokenizer, image_idx=image_idx, message=query)
        update_progress_bar(66, "submit", submit_progress_bar)
        st.session_state["query_count"] += 1
        messages.append((query, answer))
        st.session_state.past.append(query)
        st.session_state.generated.append((answer, answer_images))
        update_progress_bar(100, "submit", submit_progress_bar)

    with messages_container:
        for past_query, generated in zip(st.session_state["past"], st.session_state["generated"]):
            message_func(past_query, is_user=True)
            if isinstance(generated, tuple):
                message_func(generated[0], images=generated[1])
            else:
                message_func(generated)

    if st.session_state["query_count"] == MAX_INPUTS and RESET:
        st.warning("You have reached the maximum number of inputs. The chat history will be cleared after the next input.")

    col2.markdown(
        f'<div style="line-height: 2.5;">{st.session_state["query_count"]}/{MAX_INPUTS}</div>',
        unsafe_allow_html=True,
    )
    st.markdown('<div id="input-container-placeholder"></div>', unsafe_allow_html=True)
    # NOTE(review): components.v1.html renders inside an iframe, so ``document`` below
    # is the iframe's document, not the app page; this probably needs
    # window.parent.document — confirm.
    st.components.v1.html(
        """
<script>
window.addEventListener('load', function() {
const inputContainer = document.querySelector('.stTextInput');
const inputContainerPlaceholder = document.getElementById('input-container-placeholder');
inputContainer.id = 'input-container';
inputContainerPlaceholder.appendChild(inputContainer);
document.getElementById("input").focus();
});
</script>
""",
        height=0,
    )
Hello,
This work is truly inspiring and brilliant!
Thank you very much for making it open source. I noticed that the inference script in the repository can only generate images from text. Have you previously written a script that generates interleaved image-text content? It would be very kind of you to share it, if possible.