lm-sys / FastChat

An open platform for training, serving, and evaluating large language models. Release repo for Vicuna and Chatbot Arena.
Apache License 2.0

Loading a PEFT checkpoint fails when the PEFT weights are located outside the FastChat root folder #1853

Open srikant86panda opened 1 year ago

srikant86panda commented 1 year ago

The line at https://github.com/lm-sys/FastChat/blob/main/fastchat/model/model_adapter.py#L325 receives only the model name, instead of the full or relative path, when PEFT weights are loaded through the PEFT adapter with gradio_web_server. Code:

def get_default_conv_template(self, model_path: str) -> Conversation:
    """Uses the conv template of the base model"""
    from peft import PeftConfig, PeftModel

    config = PeftConfig.from_pretrained(model_path)  # receives just the model name, not the full or relative path

Error output:

(fast_chat) [srikapan@alpha-a100-2 FastChat]$ python3 -m fastchat.serve.gradio_web_server
2023-07-04 11:34:26 | INFO | gradio_web_server | args: Namespace(host='0.0.0.0', port=None, share=False, controller_url='http://localhost:21001', concurrency_count=10, model_list_mode='once', moderate=False, add_chatgpt=False, add_claude=False, add_palm=False, gradio_auth_path=None)
2023-07-04 11:34:26 | INFO | gradio_web_server | Models: ['starcoderbase_ns_peft']
2023-07-04 11:34:27 | INFO | stdout | Running on local URL:  http://0.0.0.0:7860
2023-07-04 11:34:27 | INFO | stdout | 
2023-07-04 11:34:27 | INFO | stdout | To create a public link, set `share=True` in `launch()`.
2023-07-04 11:34:30 | INFO | gradio_web_server | load_demo. ip: 127.0.0.1. params: {}
2023-07-04 11:34:30 | INFO | httpx | HTTP Request: POST http://localhost:7860/api/predict "HTTP/1.1 200 OK"
2023-07-04 11:34:30 | INFO | httpx | HTTP Request: POST http://localhost:7860/reset "HTTP/1.1 200 OK"
2023-07-04 11:34:34 | INFO | gradio_web_server | add_text. ip: 127.0.0.1. len: 2
2023-07-04 11:34:34 | ERROR | stderr | Traceback (most recent call last):
2023-07-04 11:34:34 | ERROR | stderr |   File "/home/srikapan/anaconda3/envs/fast_chat/lib/python3.9/site-packages/huggingface_hub/utils/_errors.py", line 259, in hf_raise_for_status
2023-07-04 11:34:34 | ERROR | stderr |     response.raise_for_status()
2023-07-04 11:34:34 | ERROR | stderr |   File "/home/srikapan/anaconda3/envs/fast_chat/lib/python3.9/site-packages/requests/models.py", line 1021, in raise_for_status
2023-07-04 11:34:34 | ERROR | stderr |     raise HTTPError(http_error_msg, response=self)
2023-07-04 11:34:34 | ERROR | stderr | requests.exceptions.HTTPError: 404 Client Error: Not Found for url: https://huggingface.co/starcoderbase_ns_peft/resolve/main/adapter_config.json

As a result, model loading fails, which causes an error in the web UI.
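The 404 URL in the traceback is what you would expect if the bare name "starcoderbase_ns_peft" were treated as a Hugging Face Hub repo ID instead of a local folder. The sketch below is a simplified illustration of a "local file first, else Hub" lookup; it is not PEFT's actual code, and the helper resolve_adapter_config is hypothetical (real Hub resolution also involves caching, revisions, and authentication). Because a relative name is checked against the current working directory, it also suggests why the adapter would load when the folder sits in the directory the server runs from, and fail when it lives elsewhere.

```python
import os
import tempfile

HUB_URL = "https://huggingface.co/{repo_id}/resolve/main/{filename}"


def resolve_adapter_config(model_path: str, filename: str = "adapter_config.json") -> str:
    """Hypothetical, simplified resolver: local file first, else a Hub URL."""
    local_file = os.path.join(model_path, filename)
    if os.path.isfile(local_file):
        # Relative paths are resolved against the current working directory.
        return local_file
    # Anything without a local config file is treated as a Hub repo ID.
    return HUB_URL.format(repo_id=model_path, filename=filename)


original_cwd = os.getcwd()
with tempfile.TemporaryDirectory() as fastchat_root, tempfile.TemporaryDirectory() as elsewhere:
    adapter_dir = os.path.join(fastchat_root, "starcoderbase_ns_peft")
    os.makedirs(adapter_dir)
    with open(os.path.join(adapter_dir, "adapter_config.json"), "w") as f:
        f.write("{}")  # placeholder config contents
    try:
        os.chdir(fastchat_root)  # server launched from the folder holding the adapter
        local = resolve_adapter_config("starcoderbase_ns_peft")
        os.chdir(elsewhere)  # adapter lives somewhere other than the working directory
        remote = resolve_adapter_config("starcoderbase_ns_peft")
    finally:
        os.chdir(original_cwd)  # restore cwd before the temp dirs are deleted

print(local)   # the local config path
print(remote)  # the same Hub URL that returned 404 in the log
```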

merrymercy commented 1 year ago

cc @fozziethebeat

fozziethebeat commented 1 year ago

Oof, I haven't tried the Gradio web server. Does it load a worker automatically within itself?

fozziethebeat commented 1 year ago

From a quick read of the code, I think someone more familiar with the Gradio server will need to suggest a fix. I'm not entirely clear on how this works for local models in general.

Key code snippets I see are:


def get_model_list(controller_url, add_chatgpt, add_claude, add_palm):
    ret = requests.post(controller_url + "/refresh_all_workers")
    assert ret.status_code == 200
    ret = requests.post(controller_url + "/list_models")
    models = ret.json()["models"]

    # Add API providers
    if add_chatgpt:
        models += ["gpt-3.5-turbo", "gpt-4"]
    if add_claude:
        models += ["claude-v1", "claude-instant-v1"]
    if add_palm:
        models += ["palm-2"]

    priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)}
    models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"Models: {models}")
    return models
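The ordering step at the end of get_model_list can be checked in isolation. In this sketch the contents of model_info and the returned model names are illustrative stand-ins, not FastChat's real registry; only the key order of model_info matters. It also shows that the list is made of plain name strings, so nothing path-like survives into the drop-down.

```python
# Hypothetical stand-in for FastChat's model_info registry; only key order matters here.
model_info = {"vicuna-13b": None, "koala-13b": None}

# Model names as a controller might return them (illustrative values).
models = ["starcoderbase_ns_peft", "koala-13b", "alpaca-13b", "vicuna-13b"]

# Known models get sort keys "___00", "___01", ... in registry order.
priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)}
# Unknown models fall back to their own name as the sort key.
models.sort(key=lambda x: priority.get(x, x))

print(models)  # ['vicuna-13b', 'koala-13b', 'alpaca-13b', 'starcoderbase_ns_peft']
```

Registered models come first because "_" (0x5F) sorts before lowercase letters; a name starting with a digit or an uppercase letter would sort ahead of them.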

And then


def add_text(state, model_selector, text, request: gr.Request):
    ip = request.client.host
    logger.info(f"add_text. ip: {ip}. len: {len(text)}")

    if state is None:
        state = State(model_selector)

And finally

class State:
    def __init__(self, model_name):
        self.conv = get_conversation_template(model_name)
        self.conv_id = uuid.uuid4().hex
        self.skip_next = False
        self.model_name = model_name

So my guess is that Gradio fetches only model names from the controller/worker and uses them to populate the drop-down selector. State then passes that name, rather than a path, to get_conversation_template.
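The suspected chain can be traced with stand-ins. None of the functions below are FastChat's real implementations: list_models_from_controller and the recording get_conversation_template are hypothetical stubs, and State is trimmed to the fields that matter. The point is only that whatever string the controller reports is exactly what reaches the conversation-template lookup, and from there the PEFT adapter.

```python
def list_models_from_controller() -> list:
    # Hypothetical stand-in for POST /list_models: the controller reports names only.
    return ["starcoderbase_ns_peft"]


received = []


def get_conversation_template(model_path: str) -> str:
    # Stub: records what the PEFT adapter's get_default_conv_template would receive.
    received.append(model_path)
    return f"conv-template-for:{model_path}"


class State:
    # Trimmed version of the State class quoted above.
    def __init__(self, model_name):
        self.conv = get_conversation_template(model_name)
        self.model_name = model_name


dropdown_choices = list_models_from_controller()
state = State(dropdown_choices[0])  # roughly what add_text does when state is None

print(received)  # ['starcoderbase_ns_peft']
```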