llava-rlhf / LLaVA-RLHF

Aligning LMMs with Factually Augmented RLHF
https://llava-rlhf.github.io/
GNU General Public License v3.0

Error when calling the model #3

Closed: LiqiangJing closed this issue 1 year ago

LiqiangJing commented 1 year ago

I used model_vqa.py, but I ran into an error with LLaVA-RLHF.

LiqiangJing commented 1 year ago
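import json
import os

import shortuuid
import torch
from peft import PeftModel
from PIL import Image
from tqdm import tqdm

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

# `args` (argparse) and get_chunk() are defined elsewhere in model_vqa.py.

disable_torch_init()
load_bf16 = True
model_path = "/home/xxx/LLaVA-RLHF/llava/eval/checkpoints/sft_model"
lora_path = "/home/xxx/LLaVA-RLHF/llava/eval/checkpoints/rlhf_lora_adapter_model"
model_name = "LLaVA-RLHF-13b-v1.5-336"
# Load the SFT base model in bf16
tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name,
                                                                       load_bf16=load_bf16)

# Attach the RLHF LoRA adapter to the SFT model, then merge it into the base weights.
# merge_and_unload() returns the underlying model without the PEFT wrapper,
# so it only needs to run once, before the loop, not once per question.
model = PeftModel.from_pretrained(
    model,
    lora_path,
)
model = model.merge_and_unload()

questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
answers_file = os.path.expanduser(args.answers_file)

os.makedirs(os.path.dirname(answers_file), exist_ok=True)
ans_file = open(answers_file, "w")
for line in tqdm(questions):
    idx = line["id"]
    image_file = line["image"]
    qs = line["instruction"]
    cur_prompt = qs
    # Prepend the image placeholder token(s) to the text instruction
    if model.config.mm_use_im_start_end:
        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
    else:
        qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

    conv = conv_templates[args.conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

    image = Image.open(os.path.join(args.image_folder, image_file))
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0].to(dtype=torch.bfloat16)

    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            # image_tensor is already bf16 (matching load_bf16); no .half() cast to fp16 needed
            images=image_tensor.unsqueeze(0).cuda(),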
            do_sample=True,
            temperature=args.temperature,
            top_p=args.top_p,
            num_beams=args.num_beams,
            # no_repeat_ngram_size=3,
            max_new_tokens=1024,
            use_cache=True)

    input_token_len = input_ids.shape[1]
    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
    outputs = outputs.strip()
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    outputs = outputs.strip()

    ans_id = shortuuid.uuid()
    ans_file.write(json.dumps({"question_id": idx,
                               "prompt": cur_prompt,
                               "text": outputs,
                               "answer_id": ans_id,
                               "model_id": model_name,
                               "metadata": {}}) + "\n")
    ans_file.flush()
ans_file.close()
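Some context on the merge_and_unload() call: with the standard LoRA scaling alpha / r, merging folds the low-rank update into the base weight a single time, W' = W + (alpha / r)·B·A. After the merge, the plain layer produces the same output as the base-plus-adapter forward pass. The toy, pure-Python sketch below checks that identity on a 2×2 layer. The matrix helpers and the values are hypothetical, and this is not PEFT's implementation.

```python
Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def add_scaled(a: Matrix, b: Matrix, s: float) -> Matrix:
    """Element-wise a + s * b."""
    return [[x + s * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def lora_forward(x: Matrix, W: Matrix, A: Matrix, B: Matrix, alpha: float, r: int) -> Matrix:
    """Unmerged LoRA layer: y = x W^T + (alpha / r) * (x A^T) B^T.

    Shapes: x (n, in), W (out, in), A (r, in), B (out, r).
    """
    base = matmul(x, transpose(W))
    delta = matmul(matmul(x, transpose(A)), transpose(B))
    return add_scaled(base, delta, alpha / r)


def merge_lora(W: Matrix, A: Matrix, B: Matrix, alpha: float, r: int) -> Matrix:
    """Merged weight W' = W + (alpha / r) * B A, with the same shape as W."""
    return add_scaled(W, matmul(B, A), alpha / r)


if __name__ == "__main__":
    W = [[1.0, 0.0], [0.0, 1.0]]
    A = [[1.0, 2.0]]          # rank r = 1
    B = [[0.5], [1.0]]
    x = [[1.0, 1.0]]
    unmerged = lora_forward(x, W, A, B, alpha=2, r=1)
    merged = matmul(x, transpose(merge_lora(W, A, B, alpha=2, r=1)))
    print(unmerged, merged)   # both [[4.0, 7.0]]
```

Merging replaces the adapter with updated weights, so it is a one-time step per loaded model. There is no need to repeat it for each question.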
Edward-Sun commented 1 year ago

Hi Liqiang,

Please let me know whether this code for setting up LLaVA serving works for you.

model_worker.py with support for loading models with LoRA and bf16: Gist

Reference for setting up the llava demo

yuvalkirstain commented 1 year ago

It's not working for me. It would help if you created a branch from the llava repo that supports this and gave exact instructions for running inference with the model.

Edward-Sun commented 1 year ago

@yuvalkirstain @LiqiangJing

Hi, could you check whether your PEFT version is >= 0.4.0? If the error comes from PEFT, I suspect it's a PEFT version issue.
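You can check the installed PEFT version against the 0.4.0 minimum mentioned above with importlib.metadata from the standard library. The sketch below is a minimal version of that check. The helper names are hypothetical. The check compares only the numeric release segment, so unlike packaging.version it treats a pre-release such as 0.4.0.dev0 as equal to 0.4.0.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_release(ver: str) -> tuple:
    """Extract the leading numeric release segment, e.g. '0.5.0' -> (0, 5, 0)."""
    match = re.match(r"(\d+(?:\.\d+)*)", ver)
    if match is None:
        raise ValueError(f"unrecognised version string: {ver!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def meets_minimum(ver: str, minimum: tuple = (0, 4, 0)) -> bool:
    """True if the release segment of `ver` is >= `minimum` (pre-release tags ignored)."""
    release = parse_release(ver)
    # Pad both tuples to the same length so (1, 0) compares like (1, 0, 0)
    width = max(len(release), len(minimum))
    release += (0,) * (width - len(release))
    minimum = tuple(minimum) + (0,) * (width - len(minimum))
    return release >= minimum


if __name__ == "__main__":
    try:
        installed = version("peft")
    except PackageNotFoundError:
        print("peft is not installed")
    else:
        status = "OK" if meets_minimum(installed) else "too old, upgrade to >= 0.4.0"
        print(f"peft {installed}: {status}")
```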

yuvalkirstain commented 1 year ago
2023-10-05 10:20:32 | ERROR | stderr | Traceback (most recent call last):
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/miniconda/envs/llava/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
2023-10-05 10:20:32 | ERROR | stderr |     self.run()
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/miniconda/envs/llava/lib/python3.10/threading.py", line 953, in run
2023-10-05 10:20:32 | ERROR | stderr |     self._target(*self._args, **self._kwargs)
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/miniconda/envs/llava/lib/python3.10/site-packages/peft/peft_model.py", line 977, in generate
2023-10-05 10:20:32 | ERROR | stderr |     outputs = self.base_model.generate(**kwargs)
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/miniconda/envs/llava/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
2023-10-05 10:20:32 | ERROR | stderr |     return func(*args, **kwargs)
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/miniconda/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py", line 1648, in generate
2023-10-05 10:20:32 | ERROR | stderr |     return self.sample(
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/miniconda/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py", line 2730, in sample
2023-10-05 10:20:32 | ERROR | stderr |     outputs = self(
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/miniconda/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
2023-10-05 10:20:32 | ERROR | stderr |     return forward_call(*args, **kwargs)
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/miniconda/envs/llava/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
2023-10-05 10:20:32 | ERROR | stderr |     output = old_forward(*args, **kwargs)
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/LLaVA/llava/model/language_model/llava_llama.py", line 78, in forward
2023-10-05 10:20:32 | ERROR | stderr |     outputs = self.model(
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/miniconda/envs/llava/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
2023-10-05 10:20:32 | ERROR | stderr |     return forward_call(*args, **kwargs)
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/miniconda/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 668, in forward
2023-10-05 10:20:32 | ERROR | stderr |     attention_mask = self._prepare_decoder_attention_mask(
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/miniconda/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 596, in _prepare_decoder_attention_mask
2023-10-05 10:20:32 | ERROR | stderr |     combined_attention_mask = _make_causal_mask(
2023-10-05 10:20:32 | ERROR | stderr |   File "/fsx-multigen/yuvalkirstain/miniconda/envs/llava/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 52, in _make_causal_mask
2023-10-05 10:20:32 | ERROR | stderr |     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
2023-10-05 10:20:32 | ERROR | stderr | RuntimeError: CUDA error: device-side assert triggered
2023-10-05 10:20:32 | ERROR | stderr | CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
2023-10-05 10:20:32 | ERROR | stderr | For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
2023-10-05 10:20:32 | ERROR | stderr | Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

The changes you suggested seem simple, so it would be helpful if you could create a branch from the llava repo that supports this and give exact instructions for running inference with the model.

Edward-Sun commented 1 year ago

Hi @yuvalkirstain, could you please check whether you can load the original llava checkpoint without any error? The error message you show here does not seem to be specific to the llava-rlhf checkpoint.

Edward-Sun commented 1 year ago

Hi, to simplify the modifications, we just added a minimal example to launch the llava-rlhf demo. Please let me know if this works for you.