Pints-AI / 1.5-Pints

A compact LLM pretrained in 9 days using high-quality data
MIT License

Training Loss Fluctuations with 0.12-Pint on Expository-Prose-V1 Dataset #10

Open · tmylla opened this issue 1 month ago

tmylla commented 1 month ago

[Image: training loss curve showing fluctuations during pretraining]

Description: While training on the Expository-Prose-V1 dataset with the provided parameters, I observed unexpected fluctuations in the loss curve (see the image above). The curve suggests instability, possibly caused by data-shuffling issues or parameter tuning. I followed the README instructions for preparing the data from Hugging Face and used the following configuration:

fabric run \
--accelerator=cuda \
--devices=8 \
pretrain/main.py \
--data_dir data/ \
--gpus 8 \
--global_batch_size 2048 \
--learning_rate 4e-4 \
--micro_batch_size 64 \
--max_step 24060 \
--warmup_steps 2000 \
--weight_decay 0.1 \
--beta1 0.9 \
--beta2 0.95 \
--grad_clip 1.0 \
--min_lr 4e-5 \
--model_name 0.12-Pint \
--wandb_name seed-42-2048 \
--wandb_project 0.12-pint \
--tokenizer_dir tokenizer/pints

Question: Could this issue be caused by data shuffling inconsistencies, or might there be other factors involved? Have you encountered similar issues, and do you have any recommended solutions?
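For reference, one quick way to spot-check whether the downloaded shards are shuffled is to look at a few consecutive rows and see whether neighbouring documents appear to come from the same contiguous source. This is only a rough sketch, not part of the repo; the directory path and the `text` column name are assumptions and may differ for Expository-Prose-V1:

```python
# Rough sketch: peek at the first rows of one downloaded parquet shard. If the
# data is unshuffled, consecutive rows often share the same document or source,
# which shows up as very similar prefixes.
# The shard directory and the 'text' column name are assumptions; adjust as needed.
from pathlib import Path

import pyarrow.parquet as pq

shard = sorted(Path('path/to/downloaded/parquet').glob('*.parquet'))[0]
rows = pq.read_table(shard, columns=['text']).slice(0, 5).column('text').to_pylist()

for text in rows:
    print(text[:120].replace('\n', ' '))
```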

tmylla commented 1 month ago

The issue was indeed caused by the data not being shuffled properly. I wrote the following script (shuffle.py) to shuffle the pre-training data:

import os
from pathlib import Path
from typing import List, Optional

import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm

def read_file(filepath: Path, parquet_columns: Optional[List[str]] = None):
    contents = pq.read_table(filepath, columns=parquet_columns)
    return contents

def read_and_shuffle_all_files(directory: str) -> pa.Table:
    all_tables = []
    file_list = [f for f in os.listdir(directory) if f.endswith('.parquet')]

    # Iterate over all files in the directory with a progress bar
    for filename in tqdm(file_list, desc="Reading and concatenating files", unit="file"):
        file_path = Path(directory) / filename
        file_contents = read_file(file_path)  # Read the entire file
        all_tables.append(file_contents)

    print("Concatenate all the tables into one ...")
    combined_table = pa.concat_tables(all_tables)

    print("Convert to pandas for shuffling ...")
    df = combined_table.to_pandas()

    print("Shuffle the DataFrame ...")
    df = df.sample(frac=1).reset_index(drop=True)

    print("Convert back to Arrow Table ...")
    shuffled_table = pa.Table.from_pandas(df)

    return shuffled_table

def save_to_parquet(table: pa.Table, batch_size: int, output_dir: str):
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True)

    print("Convert table to pandas DataFrame for batching ...")
    df = table.to_pandas()

    # Split the data into batches and save each batch as a separate parquet file with a progress bar
    for i in tqdm(range(0, len(df), batch_size), desc="Saving shuffled data", unit="batch"):
        batch_df = df.iloc[i:i + batch_size]
        batch_table = pa.Table.from_pandas(batch_df)
        output_file = output_dir / f'shuffled_part_{i//batch_size}.parquet'
        pq.write_table(batch_table, output_file)

# Custom variables
directory_path = 'your input directory path'  # Replace with your input directory path
output_directory = 'your output directory path'  # Replace with your output directory path
batch_size = 10000  # You can adjust batch size as needed

# Read, shuffle, and save the shuffled content
shuffled_table = read_and_shuffle_all_files(directory_path)
save_to_parquet(shuffled_table, batch_size, output_directory)

Running this script before the 'Prepare the dataset' step in the README resolves the issue: the parquet shards are properly shuffled, which avoids the instability observed in the loss curve.
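As a side note, `df.sample(frac=1)` draws a fresh random state on each run, so the shuffled order differs between runs. If a reproducible ordering is wanted (e.g. to mirror the seed-42 run name above), a fixed seed can be passed; a minimal sketch, with 42 chosen only as an example:

```python
# Minimal sketch: a seeded shuffle so the shard order is reproducible across runs.
# The seed value (42) is only an example, echoing the run name used above.
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
```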

calvintwr commented 1 month ago

Oh, thanks for this @tmylla. Can you put up a pull request? I will review it!