dayuyang1999 / random_code

For random code access

MP code #9

Open dayuyang1999 opened 3 weeks ago

dayuyang1999 commented 3 weeks ago
import json
from typing import List, Optional
import fire
from llama import Dialog, Llama
from tqdm import tqdm
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

SYSTEM_PROMPT = """
You are a senior data scientist responsible for writing a description for table data provided in JSON format.

An example of a good description is: This table contains data of Meridian Bank auto loan customers who are undergoing the risk-reduction process. Data is sourced from various Meridian Loan Services and enterprise tables for building customer profiles, and is used for sending out risk-reduction-related communication to customers. This dataset is owned by the Delta team. This table holds the customer correspondence dates.

Your mission is to read the tabular data point (provided in JSON format) and output a natural-language description of around 50 words. The example above is provided for reference.

The description must contain the following aspects:
1. Overall description of the data content
2. Purpose of creating this table

The description may mention the following aspects:
1. Data Subject
2. Data Usage
3. Data Source
4. Data Ownership

Your description must follow these instructions:
1. Do not reveal any specific cell data.
2. You may reasonably infer some aspects.

Your output must contain only the text description; do not explain yourself or output anything else. Be thorough!
"""

def load_tables(file_path: str) -> dict:
    with open(file_path, 'r') as f:
        return json.load(f)

def save_results(results: dict, file_path: str):
    with open(file_path, 'w') as f:
        json.dump(results, f, indent=2)

def create_dialog(table_data: dict) -> Dialog:
    return [
        {
            "role": "system",
            "content": SYSTEM_PROMPT
        },
        {
            "role": "user",
            "content": json.dumps(table_data)
        }
    ]

def batch_dialogs(dialogs: List[Dialog], batch_size: int):
    for i in range(0, len(dialogs), batch_size):
        yield dialogs[i:i + batch_size]

def process_batch(generator, batch, temperature, top_p, max_gen_len):
    try:
        return generator.chat_completion(
            batch,
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
        )
    except Exception as e:
        logging.error(f"Error processing batch: {str(e)}")
        return None

def main(
    ckpt_dir: str,
    tokenizer_path: str,
    input_file: str,
    output_file: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 512,
    max_batch_size: int = 4,
    max_gen_len: Optional[int] = None,
):
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    tables = load_tables(input_file)

    all_dialogs = [(table_id, create_dialog(table_data)) for table_id, table_data in tables.items()]

    results = {}
    skipped_tables = []

    progress_bar = tqdm(list(batch_dialogs([d for _, d in all_dialogs], max_batch_size)), desc="Processing batches")

    try:
        for batch_idx, batch in enumerate(progress_bar):
            batch_results = process_batch(generator, batch, temperature, top_p, max_gen_len)
            start = batch_idx * max_batch_size  # offset of this batch within all_dialogs

            if batch_results is None:
                logging.warning(f"Skipping batch {batch_idx} due to error")
                # Record the table IDs of this failed batch so they can be retried later
                skipped_tables.extend(table_id for table_id, _ in all_dialogs[start:start + len(batch)])
                continue

            # Map each completion back to its table ID by position
            for (table_id, _), result in zip(all_dialogs[start:start + len(batch)], batch_results):
                results[table_id] = result['generation']['content']

            if batch_idx % 10 == 0:  # Save results every 10 batches
                save_results(results, output_file)
                logging.info(f"Intermediate results saved to {output_file}")

    except KeyboardInterrupt:
        logging.info("Process interrupted by user. Saving current results...")

    finally:
        save_results(results, output_file)
        logging.info(f"All processed results saved to {output_file}")

        if skipped_tables:
            logging.warning(f"The following tables were skipped due to errors: {skipped_tables}")
            with open(f"{output_file}_skipped.json", 'w') as f:
                json.dump(skipped_tables, f, indent=2)
            logging.info(f"List of skipped tables saved to {output_file}_skipped.json")

if __name__ == "__main__":
    fire.Fire(main)
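
For reference, load_tables expects the input file to be a single JSON object mapping table IDs to table payloads, and the script writes its output as a JSON object mapping each table ID to its generated description. A minimal sketch of a valid input file follows; the payload schema is not fixed by the script (create_dialog just passes each value through json.dumps), so the columns/rows layout below is an assumption:

import json

# Hypothetical input file: a JSON object keyed by table ID. The script only
# requires that each value be JSON-serializable; the column/row layout here
# is invented for illustration.
sample_tables = {
    "table_001": {
        "columns": ["customer_id", "correspondence_date"],
        "rows": [["c-001", "2023-01-15"], ["c-002", "2023-02-03"]],
    },
    "table_002": {
        "columns": ["loan_id", "status"],
        "rows": [["l-100", "active"]],
    },
}

with open("dictionary_part_sample.json", "w") as f:
    json.dump(sample_tables, f, indent=2)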
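
If a run leaves a *_skipped.json file behind, the skipped tables can be gathered into a fresh input file and the script rerun on just those. A hypothetical retry helper; the file names follow the launcher script in the next comment but are otherwise assumptions:

import json

# Hypothetical retry helper: rebuild an input file containing only the tables
# that a previous run skipped, so the main script can be rerun on them.
with open("wikitables_descriptions_1.json_skipped.json") as f:
    skipped_ids = json.load(f)

with open("dictionary_part_1.json") as f:
    tables = json.load(f)

retry_tables = {tid: tables[tid] for tid in skipped_ids if tid in tables}

with open("dictionary_part_1_retry.json", "w") as f:
    json.dump(retry_tables, f, indent=2)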
dayuyang1999 commented 3 weeks ago
#!/bin/bash

# cluster 1
num=1
gpu_index=$(($num - 1))  # Calculates the 0-indexed GPU number
port_num=$(($num + 29500))
# Setting the specific GPU to be visible to the torchrun command
# Run the PyTorch distributed script with specified parameters
CUDA_VISIBLE_DEVICES=$gpu_index torchrun --nproc_per_node 1 --master_port $port_num /home/jovyan/test_inference/llama3-main/inference_MP_table_description.py \
    --ckpt_dir /home/jovyan/model_weights/meta-llama/Meta-Llama-3-8B-Instruct/original \
    --tokenizer_path /home/jovyan/model_weights/meta-llama/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
    --max_seq_len 2048 \
    --max_batch_size 24 \
    --temperature 0.1 \
    --top_p 0.9 \
    --input_file "/home/jovyan/project/description_data/dictionary_part_${num}.json" \
    --output_file "/home/jovyan/project/description_generation/wikitables_descriptions_${num}.json"
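
The script above hardcodes num=1 ("cluster 1"), so presumably the same command is launched once per data partition. A hedged sketch of a loop that starts one job per partition, each pinned to its own GPU and torchrun master port; the partition count of 4 is an assumption:

#!/bin/bash
# Hypothetical multi-partition launcher: one inference job per data partition,
# each on its own GPU with its own master port. Partition count is assumed.
for num in 1 2 3 4; do
    gpu_index=$(($num - 1))
    port_num=$(($num + 29500))
    CUDA_VISIBLE_DEVICES=$gpu_index torchrun --nproc_per_node 1 --master_port $port_num \
        /home/jovyan/test_inference/llama3-main/inference_MP_table_description.py \
        --ckpt_dir /home/jovyan/model_weights/meta-llama/Meta-Llama-3-8B-Instruct/original \
        --tokenizer_path /home/jovyan/model_weights/meta-llama/Meta-Llama-3-8B-Instruct/original/tokenizer.model \
        --max_seq_len 2048 \
        --max_batch_size 24 \
        --temperature 0.1 \
        --top_p 0.9 \
        --input_file "/home/jovyan/project/description_data/dictionary_part_${num}.json" \
        --output_file "/home/jovyan/project/description_generation/wikitables_descriptions_${num}.json" &
done
wait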