tabtoyou / KoLLaVA

KoLLaVA: Korean Large Language-and-Vision Assistant (feat.LLaVA)
Apache License 2.0
275 stars 30 forks source link

inference test시 에러 발생 관련 문의 #20

Closed daje0601 closed 8 months ago

daje0601 commented 8 months ago

[개요] 저는 runpods에서 a100 3장을 이용하여 fine-tuning을 하였습니다. 제가 사용한 가상환경은 pytorch:2.0.1-py3.10-cuda11.8.0입니다. 설치는 readme에 나와 있는 것처럼 아래와 같이 설치를 진행하였습니다.

pip install --upgrade pip 
pip install -e .
pip install -e ".[train]"
pip install flash-attn --no-build-isolation

위와 같은 설치 순으로 설치하고 정상적으로 fine-tuning까지 마쳤습니다.

[이슈] inference를 하기 위해 llava.eval.run_llava를 이용하여 inference를 실행하면 아래와 같이 'LlavaLlamaForCausalLM'를 import 할 수 없다고 합니다. transformers를 설치하게 되면 llava라는 게 이미 있으니 다른 이름을 사용하라고 하거든요... 제가 무슨 실수를 하고 있는지 도저히 모르겠어서 질문드려요 ㅠㅠ

import 에러가 발생되는 위치는 run_llava.py -> from llava.model.builder import load_pretrained_model -> from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, LlavaLlamaForCausalLM 여기에 있는 LlavaLalamaForCausalLM에서 발생됩니다.

image
daje0601 commented 8 months ago

이슈해결했어요~

해결전 코드

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, LlavaLlamaForCausalLM

해결후 코드

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM