Open Ninlawat-Puhu opened 1 month ago
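To fine-tune a text-to-SQL model with LlamaIndex, targeting the `claude-3-haiku-20240307` model, follow these steps. Caveat: the steps use LlamaIndex's Gradient fine-tuning integration, which fine-tunes models hosted by Gradient. Claude's weights are not public, so `claude-3-haiku-20240307` is unlikely to be a valid Gradient base model. Confirm it against Gradient's model list, or substitute an open model such as `llama2-7b-chat`.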
Install Required Packages:
%pip install llama-index-llms-gradient
%pip install llama-index-finetuning
Import Necessary Libraries:
import os
from llama_index.llms.gradient import GradientBaseModelLLM
from llama_index.finetuning import GradientFinetuneEngine
Set Environment Variables:
# Gradient credentials consumed by the Gradient LLM / fine-tuning classes.
# os.environ only accepts str values, so a missing GRADIENT_API_KEY
# (os.getenv -> None) would otherwise fail with an opaque TypeError here.
_gradient_api_key = os.getenv("GRADIENT_API_KEY")
if not _gradient_api_key:
    raise RuntimeError(
        "GRADIENT_API_KEY is not set; export your Gradient access token first."
    )
os.environ["GRADIENT_ACCESS_TOKEN"] = _gradient_api_key
# Keep a workspace ID that was already exported instead of clobbering it.
# The empty string is only a placeholder -- replace it with your workspace ID.
os.environ.setdefault("GRADIENT_WORKSPACE_ID", "")
Prepare Data: Load and save the dataset in the required format.
from datasets import load_dataset
from pathlib import Path
import json
def load_jsonl(data_dir):
    """Load a JSON Lines file through the Hugging Face ``json`` loader.

    Args:
        data_dir: Path of the JSON Lines file; normalised to a POSIX-style
            string before being passed as ``data_files``.

    Returns:
        Whatever ``load_dataset("json", ...)`` returns; callers index its
        ``"train"`` split.
    """
    posix_path = Path(data_dir).as_posix()
    return load_dataset("json", data_files=posix_path)
def save_jsonl(data_dicts, out_path):
    """Serialize *data_dicts* to *out_path* in JSON Lines format.

    Each item is encoded with ``json.dumps`` and written on its own line.
    An existing file at *out_path* is overwritten.
    """
    with open(out_path, "w") as sink:
        for record in data_dicts:
            line = json.dumps(record)
            sink.write(f"{line}\n")
def load_data_sql(
    data_dir: str = "data_sql",
    dataset_name: str = "b-mc2/sql-create-context",
):
    """Download a text-to-SQL dataset and write it as JSON Lines to *data_dir*.

    Despite its name, *data_dir* is used as the path of a single output file:
    it is opened for writing directly, and ``load_jsonl`` later reads the
    same path.

    Each row is written as
    ``{"input": question, "context": context, "output": answer}``.

    Args:
        data_dir: Path of the JSON Lines file to create. Its parent directory
            is created if missing.
        dataset_name: Hugging Face dataset to load. It must expose a
            ``train`` split with ``question``, ``context`` and ``answer``
            columns.
    """
    dataset = load_dataset(dataset_name)
    dataset_splits = {"train": dataset["train"]}
    out_path = Path(data_dir)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Open the file once. Reopening it with "w" inside the split loop would
    # truncate it on every iteration, keeping only the last split if more
    # splits were ever added to dataset_splits.
    with open(out_path, "w") as f:
        for ds in dataset_splits.values():
            for item in ds:
                newitem = {
                    "input": item["question"],
                    "context": item["context"],
                    "output": item["answer"],
                }
                f.write(json.dumps(newitem) + "\n")
# Materialize the raw text-to-SQL dataset as a JSON Lines file at "data_sql".
load_data_sql(data_dir="data_sql")
Split Data into Training/Validation Sets:
from math import ceil
def get_train_val_splits(
    data_dir: str = "data_sql",
    val_ratio: float = 0.1,
    seed: int = 42,
    shuffle: bool = True,
):
    """Split the JSON Lines file at *data_dir* into train and validation sets.

    Args:
        data_dir: Path passed to ``load_jsonl``.
        val_ratio: Fraction of rows to hold out for validation. It must lie
            strictly between 0 and 1; the validation size is rounded up.
        seed: Seed used for every shuffle, so repeated runs produce identical
            splits in identical order.
        shuffle: Whether to shuffle before splitting and to shuffle the
            resulting splits.

    Returns:
        A ``(train, val)`` tuple of datasets.

    Raises:
        ValueError: If *val_ratio* is not in the open interval (0, 1).
    """
    # Fail early with a clear message instead of deep inside train_test_split.
    if not 0 < val_ratio < 1:
        raise ValueError(f"val_ratio must be in (0, 1), got {val_ratio!r}")
    data = load_jsonl(data_dir)
    num_samples = len(data["train"])
    val_set_size = ceil(val_ratio * num_samples)
    train_val = data["train"].train_test_split(
        test_size=val_set_size, shuffle=shuffle, seed=seed
    )
    train_split, val_split = train_val["train"], train_val["test"]
    if shuffle:
        # Seed the final shuffle as well. An unseeded .shuffle() makes the
        # returned order differ between runs even though `seed` was given.
        train_split = train_split.shuffle(seed=seed)
        val_split = val_split.shuffle(seed=seed)
    return train_split, val_split
# Hold out 10% of the rows (rounded up) for validation using the default
# val_ratio, then keep raw copies of both splits on disk.
raw_train_data, raw_val_data = get_train_val_splits(data_dir="data_sql")
save_jsonl(raw_train_data, "train_data_raw.jsonl")
save_jsonl(raw_val_data, "val_data_raw.jsonl")
Map Training Data to Prompts:
# Prompt templates in an "### Instruction: ... ### Response:" layout, wrapped in
# <s> ... </s> markers (Llama-style BOS/EOS; assumed to match the base model's
# tokenizer -- confirm for the chosen model).
# The training template includes the target SQL and the closing </s>.
# The inference template stops right after "### Response:\n", so the model
# generates the SQL itself.
text_to_sql_tmpl_str = """<s>### Instruction:\n{system_message}{user_message}\n\n### Response:\n{response}</s>"""
text_to_sql_inference_tmpl_str = """<s>### Instruction:\n{system_message}{user_message}\n\n### Response:\n"""
def _generate_prompt_sql(input, context, dialect="sqlite", output=""):
    """Build a text-to-SQL prompt from a question and its table context.

    Args:
        input: Natural-language question.
        context: Context text describing the table(s) the question refers to.
        dialect: SQL dialect name shown to the model.
        output: Target SQL. When truthy, a full training example (prompt plus
            answer) is returned. When empty or ``None``, only the inference
            prompt is returned, so the model completes the answer.

    Returns:
        The formatted prompt string.
    """
    system_message = """You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question."""
    # Start with a blank line so the system message is not glued onto the
    # "### Dialect:" header. Do not end with "### Response:": both templates
    # already append that header, so repeating it here would duplicate it.
    user_message = f"""\n\n### Dialect:\n{dialect}\n### Input:\n{input}\n### Context:\n{context}"""
    if output:
        return text_to_sql_tmpl_str.format(
            system_message=system_message,
            user_message=user_message,
            response=output,
        )
    return text_to_sql_inference_tmpl_str.format(
        system_message=system_message, user_message=user_message
    )
def generate_prompt(data_point):
    """Map one raw record to ``{"inputs": <training prompt>}``.

    Reads the record's ``input``, ``context`` and ``output`` fields and
    renders them as a full training example in the SQLite dialect.
    """
    prompt = _generate_prompt_sql(
        data_point["input"],
        data_point["context"],
        dialect="sqlite",
        output=data_point["output"],
    )
    return {"inputs": prompt}
# Render every raw example into its full training prompt and save one
# {"inputs": prompt} record per line.
# The dict braces must close inside the list comprehension. Writing
# [{"inputs": d["inputs"] for d in ...}] instead builds a *dict*
# comprehension with a constant key, which collapses the whole split into a
# single record holding only the last prompt.
train_data = [{"inputs": d["inputs"]} for d in raw_train_data.map(generate_prompt)]
save_jsonl(train_data, "train_data.jsonl")
val_data = [{"inputs": d["inputs"]} for d in raw_val_data.map(generate_prompt)]
save_jsonl(val_data, "val_data.jsonl")
Run Fine-tuning with Gradient:
# Fine-tune through LlamaIndex's Gradient integration.
# NOTE(review): Gradient fine-tunes models that it hosts, and Anthropic does not
# publish Claude weights, so "claude-3-haiku-20240307" is not expected to be a
# valid Gradient base-model slug. LlamaIndex's own Gradient text-to-SQL example
# uses an open model such as "llama2-7b-chat". Confirm against Gradient's model
# list before running.
base_model_slug = "claude-3-haiku-20240307"
# Base (not fine-tuned) model. Not used further in this snippet; presumably
# kept as a baseline for comparing completions.
base_llm = GradientBaseModelLLM(base_model_slug=base_model_slug, max_tokens=300)
# Only train_data.jsonl is passed as data_path; val_data.jsonl is not given to
# the engine here.
finetune_engine = GradientFinetuneEngine(
    base_model_slug=base_model_slug,
    name="text_to_sql",
    data_path="train_data.jsonl",
    verbose=True,
    max_steps=200,
    batch_size=4,
)
# One finetune() call per epoch. Presumably each call runs another training
# pass over data_path -- confirm with the GradientFinetuneEngine docs.
epochs = 1
for i in range(epochs):
    print(f"** EPOCH {i} **")
    finetune_engine.finetune()
# Handle to the fine-tuned model with the same max_tokens as the base LLM.
# It is later queried via .complete().
ft_llm = finetune_engine.get_finetuned_model(max_tokens=300)
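Evaluation: Evaluate the fine-tuned model. The snippet below runs it on a single validation example; apply the same helper to more validation rows (or to a small toy dataset) for a broader check.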
def get_text2sql_completion(llm, raw_datapoint):
    """Ask *llm* to generate SQL for one raw ``{input, context}`` record.

    The record is rendered as an inference prompt (SQLite dialect, no target
    answer), and the completion is returned as a plain string.
    """
    prompt = _generate_prompt_sql(
        raw_datapoint["input"],
        raw_datapoint["context"],
        dialect="sqlite",
        output=None,
    )
    completion = llm.complete(prompt)
    return str(completion)
# Spot-check the fine-tuned model on one held-out example (index 2 of the
# validation split). The returned string is only displayed when this is the
# last expression of a notebook cell; wrap the call in print() when running
# as a script.
test_datapoint = raw_val_data[2]
get_text2sql_completion(ft_llm, test_datapoint)
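This process outlines how to fine-tune a model for text-to-SQL tasks using LlamaIndex and Gradient, written with `claude-3-haiku-20240307` as the target. Note that Gradient fine-tunes models it hosts, and Claude's weights are not public, so confirm that `claude-3-haiku-20240307` is actually supported before running this (or use an open model instead). For more detailed guidance, refer to the LlamaIndex documentation on fine-tuning [1][2][3].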
To continue the conversation, mention @dosu.
Question Validation
Question
I searched for information about fine-tuning text-to-SQL for Claude models in LlamaIndex, but I couldn't find any.
Could you please share information or an example of how to do that (e.g., prompts, implementation)? I use claude-3-haiku-20240307 with LlamaIndex, and it often generates invalid SQL.