haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0
20.12k stars 2.21k forks

[Question] Possibility of using LLaVA in Python 3.6? #417

Open TikaToka opened 1 year ago

TikaToka commented 1 year ago

Question

Hello @haotian-liu, thank you for sharing your amazing work!

Is there a way to use LLaVA in Python 3.6, as it requires 3.8 as a minimum, or to run it using the huggingface pipeline? I am conducting an experiment, but the baseline code is based on Python 3.6, so adapting it is super difficult (env compatibility problems + code compatibility problems).

Thank you in advance!

haotian-liu commented 1 year ago

Do you need training? If you just need inference, you can launch the LLaVA worker in one of the envs and interact with LLaVA via HTTP requests, like what we do in the gradio demo.
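For reference, the worker protocol is plain HTTP with `b"\0"`-delimited JSON chunks, so a Python 3.6 client only needs `requests`. Below is a minimal sketch of the chunk-parsing side; the endpoint path and chunk fields are taken from the gradio demo code above, but the helper name is hypothetical:

```python
import json


def parse_worker_stream(raw_chunks):
    """Decode b"\\0"-delimited JSON chunks, as streamed by the LLaVA
    model_worker's /worker_generate_stream endpoint, and return the
    text of the last successful chunk."""
    last_text = None
    for chunk in raw_chunks:
        if not chunk:  # iter_lines can yield empty keep-alive chunks
            continue
        data = json.loads(chunk.decode())
        if data.get("error_code", 0) != 0:
            raise RuntimeError("worker error: %s" % data.get("text"))
        last_text = data["text"]
    return last_text
```

Against a running worker (address assumed), usage would look like `parse_worker_stream(requests.post(worker_addr + "/worker_generate_stream", json=pload, stream=True, timeout=10).iter_lines(delimiter=b"\0"))`.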

TikaToka commented 1 year ago

Thank you for the quick response.

I am trying to adapt the code block you referenced to my baseline, but I have some questions about the code, as I am not familiar with gradio. :(

Before utilizing `http_bot` to pass inputs by request, do we need to use `add_text` to pass the inputs (text + image) to the state?

Or might just passing `pload['images'] = rgb_image` and `pload['prompt'] = prompt` work?

For clarity, I post the code for the former approach, which is still WIP.

        headers = {"User-Agent": "LLaVA Client"}

        rgb = Image.open(observations['rgb']).convert('RGB')
        depth = Image.open(observations['depth']).convert('L')
        instruction = self.inst2step(observations['instruction'])

        text = self.prompt_generator(instruction)
        image = rgb  # use the RGB observation as the image input

        text = text[:1536]  # Hard cut-off
        if image is not None:
            text = text[:1200]  # Hard cut-off for images
            if '<image>' not in text:
                # text = '<Image><image></Image>' + text
                text = text + '\n<image>'
            text = (text, image, 'Resize')
            if len(self.state.get_images(return_pil=True)) > 0:
                self.state = default_conversation.copy()
        self.state.append_message(self.state.roles[0], text)
        self.state.append_message(self.state.roles[1], None)
        self.state.skip_next = False

        prompt = self.state.get_prompt()

        all_images = self.state.get_images(return_pil=True)
        all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]

        for image, hash in zip(all_images, all_image_hash):
            t = datetime.datetime.now()
            filename = os.path.join(
                LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
            if not os.path.isfile(filename):
                os.makedirs(os.path.dirname(filename), exist_ok=True)
                image.save(filename)

        # Make requests
        pload = {
            "model": self.model_name,
            "prompt": prompt,
            "temperature": float(0.2),
            "top_p": float(top_p),
            "max_new_tokens": min(int(1200), 1536),
            "stop": self.state.sep if self.state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else self.state.sep2,
            "images": f'List of {len(self.state.get_images())} images: {all_image_hash}',

        }

        pload['images'] = state.get_images()

        self.state.messages[-1][-1] = "▌"

        try:
            response = requests.post(self.worker_addr + "/worker_generate_stream",
                                     headers=headers, json=pload, stream=True, timeout=10)
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    data = json.loads(chunk.decode())
                    if data["error_code"] == 0:
                        output = data["text"][len(prompt):].strip()
                        res = self.process_output(output)
                        return res['action']
                    else:
                        return f"Error: {data['text']} (error_code: {data['error_code']})"
        except requests.exceptions.RequestException as e:
            return "Server error."
haotian-liu commented 1 year ago

You can generate the prompt as usual (no need to follow the interactive gradio logic). You may find a reference here.

Then, make sure that everything in pload is ready, and submit the http request. Let the model_worker do the job, and it will return the result in output.
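Concretely, assembling `pload` outside gradio reduces to something like the sketch below. The field names mirror the `pload` dict in the demo code quoted above; the helper name, the default sampling values, and the idea that `images` is a list of base64-encoded strings (one per `<image>` token in the prompt) are assumptions, not a documented API:

```python
def build_pload(model_name, prompt, images_b64, stop,
                temperature=0.2, top_p=0.7, max_new_tokens=512):
    """Assemble the JSON payload sent to /worker_generate_stream.

    images_b64: list of base64-encoded image strings, one per
    <image> token appearing in the prompt.
    """
    return {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": int(max_new_tokens),
        "stop": stop,
        "images": images_b64,
    }
```

With a payload built this way, the model_worker does the rest and streams the result back in `text`.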

TikaToka commented 1 year ago

Found out that there's a GPU error; it looks like the same error as #464, and I posted detailed error information there.

How can I handle this?

2023-09-30 12:10:42 | INFO | stdout | Caught Unknown Error CUDA error: device-side assert triggered

current code:

    def load_image(self, image_file):
        import base64
        from io import BytesIO
        image = image_file
        image = image.resize((336, 336))
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
        return img_b64_str

    qs = self.get_prompt(instruction, history)
    qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    state = states[j]
    state.append_message(state.roles[0], qs)
    state.append_message(state.roles[1], None)

    prompt = state.get_prompt()

    # pload['images'] = state.get_images()

    # state.messages[-1][-1] = "▌"

    # Make requests

    image = self.load_image(ToPILImage(mode='RGB')(rgb))

    pload = {
        "model": self.model_name,
        "prompt": prompt,
        "temperature": float(0.2),
        "top_p": float(0.7),
        "max_new_tokens": 1536,
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": image,
    }

    pload['images'] = state.get_images()

    try:
        # Stream output
        response = requests.post(self.worker_addr + "/worker_generate_stream",
                                 headers=self.headers, json=pload, stream=True, timeout=100)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][len(prompt):].strip()
                    print(output, '@@@@@@@@@@@@@@@@@@@@@@@@')
                    output = self.process_output(output)
                    action = output['Action']
                    if action == 'STOP':
                        action = 0
                    else:
                        img_idxes = action.split(' ')[1]
                        action = 1
                        distances = 0.25  # TODO: can the mLLM predict this?
                    state.messages[-1][-1] = output + "▌"
                    history[j] = {
                        "action": action, "thought": output['Thought'], "step": output['Step']}
                else:
                    output = data["text"] + \
                        f" (error_code: {data['error_code']})"
                    print(output)

                time.sleep(0.03)
    except requests.exceptions.RequestException as e:
        print(e)

    cand_actions.append(action)
