Open dayuyang1999 opened 3 days ago
import json
def split_dict_equally(input_dict, chunks=8):
    """Split *input_dict* into *chunks* dictionaries of near-equal size.

    Items keep their insertion order and are assigned to consecutive chunks;
    chunk sizes differ by at most one. When ``chunks`` exceeds the number of
    items, some of the returned dictionaries are empty. A non-positive
    ``chunks`` yields an empty list.

    Args:
        input_dict: Mapping to split.
        chunks: Number of sub-dictionaries to return.

    Returns:
        A list of ``chunks`` dictionaries whose union is ``input_dict``.
    """
    # Materialise the items once instead of rebuilding the list for every chunk.
    items = list(input_dict.items())
    n = len(items)
    # Boundaries i*n//chunks spread the remainder evenly across the chunks.
    return [dict(items[i * n // chunks:(i + 1) * n // chunks]) for i in range(chunks)]
# Demo: build a sample dictionary (0..99 -> its square), split it into 8
# parts and write part k to dictionary_part_<k>.json (k = 1..8) in the
# current working directory.
# NOTE(review): the launch scripts read these files from
# /home/jovyan/project/data/ — run this from that directory or move the files.
data = {key: key * key for key in range(100)}  # Sample dictionary
sub_dictionaries = split_dict_equally(data, 8)
for part_number, part in enumerate(sub_dictionaries, start=1):
    target_path = f'dictionary_part_{part_number}.json'
    with open(target_path, 'w') as handle:
        json.dump(part, handle, indent=4)
print("Dictionaries have been split and saved as JSON files.")
num=1
# CUDA device indices are 0-based, so GPU number $num maps to index num-1.
# $((...)) performs the arithmetic; a bare "$num-1" would expand to the literal
# string "1-1". The variable name must be spelled CUDA_VISIBLE_DEVICES.
# --master_port: a distinct port per num lets several of these runs share one
#   machine (torchrun's default port is 29500).
# --max_gen_len: main() otherwise passes None to chat_completion; presumably the
#   llama library then lets generation run toward max_seq_len — TODO confirm.
#   The prompt asks for ~50 words, so 128 tokens leaves headroom.
CUDA_VISIBLE_DEVICES=$((num - 1)) torchrun --nproc_per_node 1 --master_port $((29500 + num)) /home/jovyan/test_inference/llama3-main/inference_table_description.py \
  --ckpt_dir /home/jovyan/model_weights/meta-llama/Meta-Llama-3-8B-Instruct/original \
  --tokenizer_path /home/jovyan/model_weights/meta-llama/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
  --max_seq_len 2048 \
  --max_batch_size 24 \
  --max_gen_len 128 \
  --temperature 0.1 \
  --top_p 0.9 \
  --input_file "/home/jovyan/project/data/dictionary_part_${num}.json" \
  --output_file "/home/jovyan/project/data/wikitables_descriptions_${num}.json"
#!/bin/bash
# Run table-description inference on one GPU for one data partition.
# Usage: bash <this script> [num]
#   num: 1-based GPU / partition number (default 1). The same number selects the
#        GPU and the files dictionary_part_<num>.json / wikitables_descriptions_<num>.json,
#        so e.g. num=1..8 covers the 8 parts written by split_dict_equally(data, 8).
num=${1:-1}
gpu_index=$((num - 1))  # CUDA device indices are 0-based: GPU number 1 -> index 0
# Restrict torchrun (and the model it loads) to that single GPU.
export CUDA_VISIBLE_DEVICES=$gpu_index
# --master_port: distinct per num so several instances can run side by side on
#   one machine (torchrun's default port is 29500).
# --max_gen_len: main() otherwise passes None to chat_completion; presumably the
#   llama library then allows generation up to max_seq_len — TODO confirm.
#   ~50-word descriptions fit comfortably in 128 tokens.
torchrun --nproc_per_node 1 --master_port $((29500 + num)) /home/jovyan/test_inference/llama3-main/inference_table_description.py \
  --ckpt_dir /home/jovyan/model_weights/meta-llama/Meta-Llama-3-8B-Instruct/original \
  --tokenizer_path /home/jovyan/model_weights/meta-llama/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
  --max_seq_len 2048 \
  --max_batch_size 24 \
  --max_gen_len 128 \
  --temperature 0.1 \
  --top_p 0.9 \
  --input_file "/home/jovyan/project/data/dictionary_part_${num}.json" \
  --output_file "/home/jovyan/project/data/wikitables_descriptions_${num}.json"
import json
from typing import List, Optional
import fire
from llama import Dialog, Llama
from tqdm import tqdm
import torch
# System prompt sent as the "system" message of every dialog built by
# create_dialog(); the table itself follows as the "user" message.
# NOTE(review): the wording has grammatical slips ("that responsible to write",
# "must contains", "only contains a text description only"), places "I provide
# you an example" after the example, and asks for ~50 words while ending with
# "Be thoroughgoing!". Any edit to this text changes model outputs, so it is
# kept verbatim here — revise deliberately and re-evaluate.
SYSTEM_PROMPT = """
You are a senior data scientist that responsible to write a description for a table data in json format.
An example of a good description is: This table contains data of Meridian Bank auto loan customers who are undergoing through the risk-reduction process. Data is sourced from various Meridian Loan Services and enterprise tables for building customer profile and used for sending out the risk-reduction related communication to customers. This dataset is owned by the Delta team. This table holds the customer correspondence dates.
Your mission is to read the tabular datapoint(provided in json format) and output a natural language description that contains around 50 words. I provide you an example for reference.
The description must contains the following aspects:
1. Overall description of data Content
2. Purpose of creating this table
The description may mention the following aspects:
1. Data Subject
2. Data Usage
3. Data Source
4. Data Ownership
Your description should follow the instructions:
1. do not reveal any specific cell data
2. you can reasonably infer some aspects
Your output must always only contains a text description only, do not explain yourself or output anything else. Be thoroughgoing!
"""
def load_tables(file_path: str) -> dict:
    """Load and return the JSON document stored at *file_path*.

    The file is decoded as UTF-8 explicitly (the encoding JSON interchange
    requires) instead of the platform's locale-dependent default.

    Args:
        file_path: Path to a JSON file. ``main`` expects a top-level object
            mapping table id -> table data.

    Returns:
        The parsed JSON value (a ``dict`` for the expected input).
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)
def save_results(results: dict, file_path: str):
    """Write *results* to *file_path* as JSON indented by 2 spaces, atomically.

    The JSON is written to ``<file_path>.tmp`` first and then moved over
    *file_path* with ``os.replace``. ``main`` rewrites the output after every
    batch, so an interruption (e.g. Ctrl-C) mid-write leaves the previous
    complete file in place rather than a truncated one.

    Args:
        results: JSON-serialisable mapping to store.
        file_path: Destination path; overwritten if it exists.
    """
    import os

    tmp_path = f"{file_path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    # Rename is atomic when source and target are on the same filesystem.
    os.replace(tmp_path, file_path)
def create_dialog(table_data: dict) -> Dialog:
    """Build the two-message chat dialog that asks the model to describe a table.

    Args:
        table_data: JSON-serialisable table content for one table.

    Returns:
        A list with a "system" message holding ``SYSTEM_PROMPT`` and a "user"
        message holding *table_data* serialised as JSON.
    """
    return [
        {
            "role": "system",
            "content": SYSTEM_PROMPT,
        },
        {
            "role": "user",
            # ensure_ascii=False keeps non-ASCII cell text as real characters
            # instead of \uXXXX escapes, which read worse for the model and cost
            # extra prompt tokens against max_seq_len.
            "content": json.dumps(table_data, ensure_ascii=False),
        },
    ]
def batch_dialogs(dialogs: List[Dialog], batch_size: int):
    """Lazily yield consecutive slices of *dialogs* of length *batch_size*.

    The last slice is shorter when ``len(dialogs)`` is not a multiple of
    ``batch_size``. Nothing is yielded for an empty list.
    """
    offsets = range(0, len(dialogs), batch_size)
    yield from (dialogs[offset:offset + batch_size] for offset in offsets)
def process_batch(generator, batch, temperature, top_p, max_gen_len):
    """Run one ``chat_completion`` call for *batch*, returning None on failure.

    Args:
        generator: Object exposing ``chat_completion`` (the ``Llama`` instance
            built in ``main``).
        batch: List of dialogs to complete together.
        temperature: Sampling temperature forwarded to ``chat_completion``.
        top_p: Nucleus-sampling threshold forwarded to ``chat_completion``.
        max_gen_len: Generation limit forwarded to ``chat_completion``.

    Returns:
        Whatever ``chat_completion`` returns, or ``None`` if it raised an
        ``Exception``. In that case the traceback is printed so the failure can
        be diagnosed.

    Raises:
        KeyboardInterrupt, SystemExit: propagated rather than swallowed, so
        Ctrl-C stops the run instead of skipping to the next batch.
    """
    try:
        return generator.chat_completion(
            batch,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
        )
    except Exception:  # best-effort: one failing batch must not abort the run
        import traceback

        traceback.print_exc()
        return None
def main(
    ckpt_dir: str,
    tokenizer_path: str,
    input_file: str,
    output_file: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 512,
    max_batch_size: int = 4,
    max_gen_len: Optional[int] = None,
    resume: bool = False,
):
    """Generate a natural-language description for every table in *input_file*.

    Tables are sent to the model in batches of ``max_batch_size`` dialogs. The
    table id -> description mapping is rewritten to *output_file* after every
    successful batch, so progress made before an interruption stays on disk. A
    batch whose ``chat_completion`` call fails is skipped and reported.

    Args:
        ckpt_dir: Checkpoint directory passed to ``Llama.build``.
        tokenizer_path: Tokenizer path passed to ``Llama.build``.
        input_file: JSON file holding an object that maps table id -> table data.
        output_file: JSON file that receives the table id -> description mapping.
        temperature: Sampling temperature forwarded to ``chat_completion``.
        top_p: Nucleus-sampling threshold forwarded to ``chat_completion``.
        max_seq_len: Context length passed to ``Llama.build``.
        max_batch_size: Batch size passed to ``Llama.build``; also the number
            of dialogs sent per ``chat_completion`` call.
        max_gen_len: Generation limit forwarded to ``chat_completion``; ``None``
            leaves the limit to the llama library.
        resume: When True and *output_file* already exists, keep its entries
            and only process tables whose id is not in it yet (e.g. to continue
            an interrupted run or retry skipped batches).
    """
    import os

    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )
    tables = load_tables(input_file)

    results = {}
    if resume and os.path.exists(output_file):
        results = load_tables(output_file)
        print(f"Resuming: {len(results)} tables already described in {output_file}")

    # Keep each table id next to its dialog so a batch's results can be matched
    # back to ids without recomputing slice offsets.
    pending = [
        (table_id, create_dialog(table_data))
        for table_id, table_data in tables.items()
        if table_id not in results
    ]

    skipped_batches = []
    batch_starts = range(0, len(pending), max_batch_size)
    for batch_idx, start in enumerate(tqdm(batch_starts, desc="Processing batches")):
        batch = pending[start:start + max_batch_size]
        dialogs = [dialog for _, dialog in batch]
        batch_results = process_batch(generator, dialogs, temperature, top_p, max_gen_len)
        if batch_results is None:
            # tqdm.write prints above the progress bar instead of breaking it.
            tqdm.write(f"Skipping batch {batch_idx} due to error")
            skipped_batches.append(batch_idx)
            continue
        for (table_id, _), result in zip(batch, batch_results):
            results[table_id] = result['generation']['content']
        save_results(results, output_file)

    # Write once more unconditionally: otherwise an empty input, or a run in
    # which every batch failed, would never create output_file.
    save_results(results, output_file)
    print(f"All tables processed. Results saved to {output_file}")
    if skipped_batches:
        print(f"Warning: The following batch indices were skipped due to errors: {skipped_batches}")
        print("Re-run with --resume to retry only the tables that are still missing.")
# Script entry point: python-fire exposes main()'s parameters as command-line
# flags (e.g. --ckpt_dir, --input_file, as used by the torchrun launch commands).
if __name__ == "__main__":
    fire.Fire(main)
Processing batches: 0%|▎ | 4/1216 [05:47<23:56:17, 71.10s/it]^C
W0704 02:09:29.124466 139926047135552 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 21673 closing signal SIGTERM
Traceback (most recent call last):
  File "/opt/conda/envs/llm/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 733, in run
    result = self._invoke_run(role)
  File "/opt/conda/envs/llm/lib/python3.9/site-packages/torch/distributed/elastic/agent/server/api.py", line 876, in _invoke_run
    time.sleep(monitor_interval)
  File "/opt/conda/envs/llm/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 76, in _terminate_process_handler
    raise SignalException(f"Process {os.getpid()} got signal: {sigval}", sigval=sigval)
torch.distributed.elastic.multiprocessing.api.SignalException: Process 21560 got signal: 2