huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
135.37k stars 27.09k forks source link

Gradual Slowdown of training on bigger batches #34685

Open sidharthg-couture opened 1 week ago

sidharthg-couture commented 1 week ago

System Info

Who can help?

@muellerzr @ArthurZucker

Information

Tasks

Reproduction

Issue

As you can see, training slows down progressively over time for larger batch sizes.

Experiment Details

Training Code

import traceback
from datetime import datetime
# from accelerate.logging import get_logger
# logging = get_logger(__name__, log_level="INFO")

from datasets import Dataset, load_dataset
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
enable_progress_bar()
# disable_progress_bar()

from peft import LoraConfig, get_peft_model
from peft import PeftConfig, PeftModel

from transformers import AutoModel, AutoTokenizer, AutoConfig, TrainerCallback, get_linear_schedule_with_warmup

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling
from sentence_transformers.evaluation import SentenceEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss, GISTEmbedLoss
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import MultiDatasetBatchSamplers, SentenceTransformerTrainingArguments, BatchSamplers
from sentence_transformers.sampler import NoDuplicatesBatchSampler, ProportionalBatchSampler
from sentence_transformers import LoggingHandler, util
from transformers import BitsAndBytesConfig

import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import BatchSampler, ConcatDataset, SubsetRandomSampler

from typing import Any, Iterator
from collections import defaultdict
from itertools import accumulate, cycle

import pandas as pd
import numpy as np
import gc
import logging
import os
import random

RANDOM_STATE = 0

random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

# setting env variables
# NOTE(review): `datasets` / `huggingface_hub` are imported above this point and
# presumably read HF_HOME / HF_DATASETS_CACHE at import time, so these two
# assignments may not take effect in this process — confirm, and export them in
# the launcher (or set them before the imports) if the cache location matters.
# load_dataset() in main() passes cache_dir=CACHE_DIR explicitly anyway.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_DATASETS_CACHE"] = "/data/hgfc_cache_new"
os.environ["HF_HOME"] = "/data/hgfc_cache_new"

TRIAL = False  # control param for test runs

BASE_DIR = "/data/training_runs/stella_mnr_train_epoch3"

CACHE_DIR = "/data/hgfc_cache_new"

# if TRIAL: BASE_DIR = "/data/monil/stella_mnr_train_final"

# exist_ok=True instead of an exists()-check: every launcher rank executes this
# module, and a check-then-create sequence can race into FileExistsError.
os.makedirs(BASE_DIR, exist_ok=True)

# Hyper Params
CONFIG = {
    "base_model_path": "/data/models/stella_checkpoint_2ep",

    "triplet_data_path": "/data/datasets/prompted_train_data_full_run/triplet_data",
    "duplet_data_path": "/data/datasets/prompted_train_data_full_run/duplet_data",
    "query_pair_path": "/data/datasets/prompted_train_data_full_run/query_pair_data",
    "product_descriptions_path": "/data/datasets/product_descriptions_with_keywords_14072024.pkl",
    "deepspeed_config_path": "/app/notebooks/monil/deepspeed.config",
    "keep_in_memory": True,

    "test_samples": 100 if TRIAL else 10000,
    "dev_samples": 100 if TRIAL else 5000,
    "batch_size": 32,
    "accumulation_step": 1,
    "eval_steps": 10 if TRIAL else 10000,
    "num_epochs": 3,

    "output_dir": f"{BASE_DIR}/model_checkpoints_bf16",
    "logging_dir": f"{BASE_DIR}/train_logs/run_4x32_in_memory_fixed",
    "datasets_dir": f"{BASE_DIR}/val_test_datasets",
}

os.makedirs(CONFIG["logging_dir"], exist_ok=True)
os.makedirs(CONFIG["output_dir"], exist_ok=True)
os.makedirs(CONFIG["datasets_dir"], exist_ok=True)

# Shared logging settings; full runs additionally write to <logging_dir>/logs.txt,
# trial runs log to stderr.
_LOG_KWARGS = {
    "format": "%(asctime)s - %(message)s",
    "datefmt": "%Y-%m-%d %H:%M:%S",
    "level": logging.INFO,
}
if TRIAL:
    logging.basicConfig(**_LOG_KWARGS)
else:
    logging.basicConfig(**_LOG_KWARGS, filename=CONFIG["logging_dir"] + "/logs.txt")

logging.info(CONFIG)

# custom evaluator class
class SearchEvaluator(SentenceEvaluator):
    """
    Evaluate a SentenceTransformer model on product search.

    Every dev query is embedded and matched against a (sub-sampled) corpus of
    product descriptions with ``util.semantic_search``. For a query with ``k``
    relevant products in the sampled corpus, the query score is the fraction of
    those ``k`` products found among the top-``k`` hits. The returned metric is
    the mean score over all queries that have at least one relevant product.

    Args:
        dev_dict: mapping ``query -> {"product_code": set, "anti_product_code": set}``.
        product_descriptions: indexable ``product_code -> description text``.
        num_sample_products: maximum number of products kept in the corpus.
    """

    def __init__(self, dev_dict, product_descriptions, num_sample_products=50000):
        super().__init__()

        random.seed(RANDOM_STATE)
        torch.manual_seed(RANDOM_STATE)
        self.product_descriptions = product_descriptions

        # Every product mentioned as relevant or irrelevant for any query.
        candidate_products = set()
        for values in dev_dict.values():
            candidate_products.update(values["product_code"])
            candidate_products.update(values["anti_product_code"])
        # Sort before sampling: the iteration order of a set of strings depends on
        # the per-process hash seed, so sampling from list(set) would select a
        # different corpus on every run / rank despite random.seed().
        # (assumes product codes are mutually comparable, e.g. all str — verify)
        candidate_products = sorted(candidate_products)
        self.all_products = random.sample(
            candidate_products, min(len(candidate_products), num_sample_products)
        )

        # Restrict each query's labels to the sampled corpus. The membership set
        # is built once rather than twice per query.
        sampled = set(self.all_products)
        self.dev_dict = {
            key: {
                "product_code": values["product_code"].intersection(sampled),
                "anti_product_code": values["anti_product_code"].intersection(sampled),
            }
            for key, values in dev_dict.items()
        }

        # Largest number of relevant products of any query; retrieval must return
        # at least this many hits. default=0 keeps an empty dev_dict from crashing.
        self.max_top_k = max(
            (len(value["product_code"]) for value in self.dev_dict.values()), default=0
        )
        self.csv_file = "SearchEvaluator_results.csv"
        print("Unique products", len(self.all_products))
        print("Max top k", self.max_top_k)

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        """
        Return the mean per-query precision of ``model`` on the dev queries.

        Returns 0.0 (instead of dividing by zero) when no query has a relevant
        product in the sampled corpus.
        """
        if self.max_top_k == 0:
            logging.warning("SearchEvaluator: no dev query has a relevant product in the sampled corpus; returning 0.0")
            return 0.0

        corpus_embeddings = model.encode(
            [self.product_descriptions[code] for code in self.all_products],
            convert_to_tensor=True,
            show_progress_bar=False,
        )
        queries = list(self.dev_dict.keys())
        query_embedding = model.encode(queries, convert_to_tensor=True, show_progress_bar=False)

        prec = 0.0
        valid_queries = 0
        hits_per_query = util.semantic_search(query_embedding, corpus_embeddings, top_k=self.max_top_k)
        for query, hits in zip(queries, hits_per_query):
            orig_products = self.dev_dict[query]["product_code"]
            num_products = len(orig_products)
            if num_products == 0:
                # All relevant products were dropped by sub-sampling; skip the query.
                continue
            retrieved_products = {self.all_products[hit["corpus_id"]] for hit in hits[:num_products]}
            prec += len(retrieved_products & orig_products) / num_products
            valid_queries += 1

        prec = prec / valid_queries if valid_queries else 0.0

        # Previously the step number was formatted into the only placeholder,
        # so the precision itself never reached the log.
        logging.info("precision on dev set at step {}: {}".format(steps, prec))

        # Free the embedding tensors before training resumes.
        del corpus_embeddings
        del query_embedding
        gc.collect()
        torch.cuda.empty_cache()

        return prec

def _load_parquet_split(data_dir, split):
    """Load every ``*.parquet`` file under ``<data_dir>/<split>/`` into a DatasetDict.

    All files land under the DatasetDict's ``"train"`` key regardless of
    ``split``, which is why callers index ``["train"]`` even for dev data.
    """
    return load_dataset(
        "parquet",
        data_files=f"{data_dir}/{split}/*.parquet",
        cache_dir=CACHE_DIR,
        keep_in_memory=CONFIG["keep_in_memory"],
    )


def _as_anchor_split(dataset_dict, columns):
    """Keep ``columns``, shuffle with RANDOM_STATE, rename ``query`` to ``anchor`` and return the ``"train"`` split."""
    return (
        dataset_dict.select_columns(columns)
        .shuffle(seed=RANDOM_STATE)
        .rename_column("query", "anchor")["train"]
    )


def main():
    """Fine-tune the base SentenceTransformer with MultipleNegativesRankingLoss.

    Trains on the triplet, duplet and query-pair parquet datasets and evaluates
    with SearchEvaluator. Reads ``LOCAL_RANK`` (set by launchers such as
    ``deepspeed``) to pick the GPU, defaulting to 0 when it is absent.
    """
    print("entered main")

    logging.info(CONFIG)

    # loading data
    logging.info("Imports successful, loading data")

    # NOTE: read_pickle can execute arbitrary code; only load trusted files.
    product_descriptions = pd.read_pickle(CONFIG["product_descriptions_path"])
    logging.info("loaded {} unique product descriptions in catalogue".format(len(product_descriptions)))

    triplet_train_dataset = _load_parquet_split(CONFIG["triplet_data_path"], "train")
    logging.info(f"Triplet train dataset:  {triplet_train_dataset}")
    logging.info(triplet_train_dataset["train"][0])

    triplet_dev_dataset = _load_parquet_split(CONFIG["triplet_data_path"], "dev")
    logging.info(f"Triplet dev dataset:  {triplet_dev_dataset}")
    logging.info(triplet_dev_dataset["train"][0])

    # One row per query; every other column is aggregated into a set.
    grouped_dev_triplets = triplet_dev_dataset["train"].to_pandas().groupby("query").agg(set).reset_index()
    dev_dict = {
        query: {"product_code": selected, "anti_product_code": anti_selected}
        for query, selected, anti_selected in zip(
            grouped_dev_triplets["query"],
            grouped_dev_triplets["selected_ids"],
            grouped_dev_triplets["anti_selected_ids"],
        )
    }
    logging.info("Triplet Data dev query count: {}".format(len(dev_dict)))

    duplet_train_dataset = _load_parquet_split(CONFIG["duplet_data_path"], "train")
    logging.info(f"Duplet train dataset:  {duplet_train_dataset}")
    logging.info(duplet_train_dataset["train"][0])

    query_pair_train_dataset = _load_parquet_split(CONFIG["query_pair_path"], "train")
    logging.info(f"Query pair train dataset: {query_pair_train_dataset}")
    logging.info(query_pair_train_dataset["train"][0])

    final_train_dataset = {
        "product_triplets": _as_anchor_split(triplet_train_dataset, ["query", "positive", "negative"]),
        "product_duplets": _as_anchor_split(duplet_train_dataset, ["query", "positive"]),
        "hinglish_duplets": _as_anchor_split(query_pair_train_dataset, ["query", "positive"]),
    }

    final_dev_dataset = {
        "product_triplets": _as_anchor_split(triplet_dev_dataset, ["query", "positive"]),
    }

    # model loading
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    logging.info("loading model for finetuning")
    device = f"cuda:{local_rank}"

    # device= already places the model on this rank's GPU; no extra .to() needed.
    model = SentenceTransformer(CONFIG["base_model_path"], trust_remote_code=True, device=device)

    logging.info("Final Finetuning Model:\n{}".format(model))

    gc.collect()
    torch.cuda.empty_cache()

    logging.info("triggering trainer")
    args = SentenceTransformerTrainingArguments(
        do_train=True,
        do_eval=True,
        # Required parameter:
        output_dir=CONFIG["output_dir"],
        overwrite_output_dir=False,
        # Optional training parameters:
        num_train_epochs=CONFIG["num_epochs"],
        per_device_train_batch_size=CONFIG["batch_size"],
        per_device_eval_batch_size=CONFIG["batch_size"],
        gradient_accumulation_steps=CONFIG["accumulation_step"],
        load_best_model_at_end=True,
        metric_for_best_model="eval_evaluator",
        learning_rate=5e-6,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        # NOTE(review): suspect this sampler when throughput degrades within an
        # epoch for large batches — it must skip candidates that duplicate text
        # already in the batch, which may get harder as the epoch drains. Profile
        # against BatchSamplers.BATCH_SAMPLER to confirm before changing.
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # Optional tracking/debugging parameters:
        max_grad_norm=1.0,
        deepspeed=CONFIG["deepspeed_config_path"],
        disable_tqdm=False,
        logging_dir=CONFIG["logging_dir"],
        eval_strategy="steps",
        eval_steps=CONFIG["eval_steps"],
        save_strategy="steps",
        save_steps=CONFIG["eval_steps"],
        save_total_limit=20,
        logging_steps=10,
        save_safetensors=False,
        eval_on_start=False,
        torch_empty_cache_steps=None,
    )

    search_evaluator = SearchEvaluator(dev_dict, product_descriptions)

    train_loss = MultipleNegativesRankingLoss(model)

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=final_train_dataset,
        eval_dataset=final_dev_dataset,
        loss=train_loss,
        evaluator=search_evaluator,
    )
    trainer.train()

    trainer.save_model()
    trainer.save_state()


if __name__ == "__main__":
    main()

DeepSpeed config (a standard config, provided in case it helps)

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 2
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}

Reproduction Steps

I simply run `deepspeed final_mnr_training_barebones.py`.

GPU utilisation for different batch sizes

VRAM usage stays fairly constant throughout training, fluctuating by only about 0.5–1%.

Conclusion

I cannot figure out why training slows down as shown for larger batch sizes. With MNR loss, larger batch sizes are preferred, and ideally they would also make training finish faster, provided this slowdown did not occur.

I have spent quite some time trying to find the cause of this issue, but have not been able to. Any help will be appreciated. Thanks!

Expected behavior

I expect training throughput to stay constant as training proceeds, since temperature and VRAM usage statistics show no visible sign of throttling. Ideally, the epoch-vs-time graph should be a straight line for all batch sizes.

Any help regarding this issue will be appreciated. Thanks!

qubvel commented 1 week ago

Linked issue on the sentence-transformers repo: