run-llama / llama_index

LlamaIndex is a data framework for your LLM applications
https://docs.llamaindex.ai

[Question]: How to fine tune text-to-SQL for claude model? #13710

Open · Ninlawat-Puhu opened 1 month ago

Ninlawat-Puhu commented 1 month ago


Question

I tried to search for information about fine-tuning text-to-SQL for a Claude model in LlamaIndex, but I couldn't find anything.

Could you please share information or an example of how to do that (e.g., prompts, implementation)? I am using claude-3-haiku-20240307 with LlamaIndex, and it usually generates invalid SQL.
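
For context, here is a minimal sketch of the setup I am describing (the database URL and table name are placeholders):

    from sqlalchemy import create_engine
    from llama_index.core import SQLDatabase
    from llama_index.core.query_engine import NLSQLTableQueryEngine
    from llama_index.llms.anthropic import Anthropic

    # Placeholder database and table; the real schema is larger.
    engine = create_engine("sqlite:///example.db")
    sql_database = SQLDatabase(engine, include_tables=["city_stats"])

    llm = Anthropic(model="claude-3-haiku-20240307")
    query_engine = NLSQLTableQueryEngine(sql_database=sql_database, llm=llm)

    # The SQL generated behind this call is frequently invalid.
    response = query_engine.query("Which city has the highest population?")
    print(response)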

dosubot[bot] commented 1 month ago

To fine-tune a text-to-SQL model with LlamaIndex, using claude-3-haiku-20240307 as the base model, follow these steps:

  1. Install Required Packages:

    %pip install llama-index-llms-gradient
    %pip install llama-index-finetuning
  2. Import Necessary Libraries:

    import os
    from llama_index.llms.gradient import GradientBaseModelLLM
    from llama_index.finetuning import GradientFinetuneEngine
  3. Set Environment Variables:

    os.environ["GRADIENT_ACCESS_TOKEN"] = os.getenv("GRADIENT_API_KEY")
    os.environ["GRADIENT_WORKSPACE_ID"] = ""
  4. Prepare Data: Load and save the dataset in the required format.

    from datasets import load_dataset
    from pathlib import Path
    import json
    
    def load_jsonl(data_dir):
        # Load a JSONL file as a Hugging Face dataset.
        data_path = Path(data_dir).as_posix()
        data = load_dataset("json", data_files=data_path)
        return data

    def save_jsonl(data_dicts, out_path):
        # Write a list of dicts to disk, one JSON object per line.
        with open(out_path, "w") as fp:
            for data_dict in data_dicts:
                fp.write(json.dumps(data_dict) + "\n")

    def load_data_sql(data_dir: str = "data_sql"):
        # Download b-mc2/sql-create-context and write it as JSONL with
        # input/context/output fields.
        dataset = load_dataset("b-mc2/sql-create-context")
        dataset_splits = {"train": dataset["train"]}
        out_path = Path(data_dir)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        for key, ds in dataset_splits.items():
            with open(out_path, "w") as f:
                for item in ds:
                    newitem = {
                        "input": item["question"],
                        "context": item["context"],
                        "output": item["answer"],
                    }
                    f.write(json.dumps(newitem) + "\n")

    load_data_sql(data_dir="data_sql")
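
    # Optional sanity check (not in the original answer): read the file back
    # with the load_jsonl helper above and inspect one record.
    sanity = load_jsonl("data_sql")
    print(sanity["train"][0])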
  5. Split Data into Training/Validation Sets:

    from math import ceil
    
    def get_train_val_splits(data_dir: str = "data_sql", val_ratio: float = 0.1, seed: int = 42, shuffle: bool = True):
        # Split the JSONL dataset into shuffled train and validation sets.
        data = load_jsonl(data_dir)
        num_samples = len(data["train"])
        val_set_size = ceil(val_ratio * num_samples)
        train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=shuffle, seed=seed)
        return train_val["train"].shuffle(), train_val["test"].shuffle()

    raw_train_data, raw_val_data = get_train_val_splits(data_dir="data_sql")
    save_jsonl(raw_train_data, "train_data_raw.jsonl")
    save_jsonl(raw_val_data, "val_data_raw.jsonl")
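
    # Optional check (not in the original answer): confirm the split sizes.
    print(f"train: {len(raw_train_data)}, val: {len(raw_val_data)}")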
  6. Map Training Data to Prompts:

    text_to_sql_tmpl_str = """<s>### Instruction:\n{system_message}{user_message}\n\n### Response:\n{response}</s>"""
    text_to_sql_inference_tmpl_str = """<s>### Instruction:\n{system_message}{user_message}\n\n### Response:\n"""

    def _generate_prompt_sql(input, context, dialect="sqlite", output=""):
        system_message = """You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question."""
        user_message = f"""### Dialect:\n{dialect}\n### Input:\n{input}\n### Context:\n{context}\n### Response:\n"""
        if output:
            return text_to_sql_tmpl_str.format(system_message=system_message, user_message=user_message, response=output)
        else:
            return text_to_sql_inference_tmpl_str.format(system_message=system_message, user_message=user_message)

    def generate_prompt(data_point):
        full_prompt = _generate_prompt_sql(data_point["input"], data_point["context"], dialect="sqlite", output=data_point["output"])
        return {"inputs": full_prompt}

    # Build a list of dicts, one {"inputs": ...} record per example.
    train_data = [{"inputs": d["inputs"]} for d in raw_train_data.map(generate_prompt)]
    save_jsonl(train_data, "train_data.jsonl")
    val_data = [{"inputs": d["inputs"]} for d in raw_val_data.map(generate_prompt)]
    save_jsonl(val_data, "val_data.jsonl")
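
    # Optional check (not in the original answer): render one training example
    # to verify the prompt formatting before launching a job.
    print(train_data[0]["inputs"])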
  7. Run Fine-tuning with Gradient:

    base_model_slug = "claude-3-haiku-20240307"
    base_llm = GradientBaseModelLLM(base_model_slug=base_model_slug, max_tokens=300)

    finetune_engine = GradientFinetuneEngine(
        base_model_slug=base_model_slug,
        name="text_to_sql",
        data_path="train_data.jsonl",
        verbose=True,
        max_steps=200,
        batch_size=4,
    )

    # Each call to finetune() runs one pass over the training data.
    epochs = 1
    for i in range(epochs):
        print(f"** EPOCH {i} **")
        finetune_engine.finetune()

    ft_llm = finetune_engine.get_finetuned_model(max_tokens=300)
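
    # Optional (an assumption, based on the LlamaIndex Gradient examples, that
    # the engine exposes its adapter id): record it for resuming later runs.
    print(finetune_engine.model_adapter_id)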
  8. Evaluation: Evaluate the fine-tuned model on validation data and a toy dataset.

    def get_text2sql_completion(llm, raw_datapoint):
        # Build an inference prompt (no gold output) and complete it with the LLM.
        prompt = _generate_prompt_sql(raw_datapoint["input"], raw_datapoint["context"], dialect="sqlite")
        response = llm.complete(prompt)
        return str(response)

    test_datapoint = raw_val_data[2]
    get_text2sql_completion(ft_llm, test_datapoint)
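
    # A stricter check (a sketch, not from the original answer): build an
    # in-memory SQLite database from the CREATE TABLE statement in the
    # datapoint's "context" field, then try to execute the generated SQL.
    # This catches the syntactically invalid queries the question complains
    # about. Assumes the context holds executable DDL, as it does in
    # b-mc2/sql-create-context.
    import sqlite3

    def sql_is_valid(raw_datapoint, generated_sql):
        conn = sqlite3.connect(":memory:")
        try:
            conn.executescript(raw_datapoint["context"])
            conn.execute(generated_sql)
            return True
        except sqlite3.Error:
            return False
        finally:
            conn.close()

    print(sql_is_valid(test_datapoint, get_text2sql_completion(ft_llm, test_datapoint)))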

This process outlines how to fine-tune the claude-3-haiku-20240307 model for text-to-SQL tasks using LlamaIndex and Gradient. For more detailed guidance, refer to the LlamaIndex documentation on fine-tuning [1][2][3].

**About Dosu:** This response is meant to be useful and save you time. It is **not meant to be a precise solution**, but rather a starting point for your own research. Help me be more useful! Please leave a 👍 if this is helpful and 👎 if it is irrelevant.

To continue the conversation, mention @dosu.