huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Huge Num Epochs (9223372036854775807) when using Trainer API with streaming dataset #22757

Closed oonisim closed 1 year ago

oonisim commented 1 year ago

System Info

Running on SageMaker Studio g4dn 2xlarge.

!cat /etc/os-release
PRETTY_NAME="Debian GNU/Linux 10 (buster)"
!transformers-cli env
- `transformers` version: 4.28.0
- Platform: Linux-4.14.309-231.529.amzn2.x86_64-x86_64-with-debian-10.6
- Python version: 3.7.10
- Huggingface_hub version: 0.13.4
- Safetensors version: not installed
- PyTorch version (GPU?): 1.13.1+cu117 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: YES
- Using distributed or parallel set-up in script?: <fill in>
!nvidia-smi
Fri Apr 14 04:32:30 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:1E.0 Off |                    0 |
| N/A   32C    P0    25W /  70W |  13072MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

Background

Fine-tune the BLOOM model for summarization.

Problem

When using a streaming Hugging Face dataset, the Trainer API reports a huge Num Epochs = 9,223,372,036,854,775,807.

trainer.train()
-----
***** Running training *****
  Num examples = 6,144
  Num Epochs = 9,223,372,036,854,775,807      <-----
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 6,144
  Number of trainable parameters = 559,214,592
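
For reference, 9,223,372,036,854,775,807 is sys.maxsize (2**63 - 1 on a 64-bit platform), and a streaming dataset is an IterableDataset with no length the Trainer could use to derive the number of epochs. A quick check (separate from the training script below):

import sys
from datasets import load_dataset

print(sys.maxsize)                    # 9223372036854775807
stream = load_dataset("xsum", split="train", streaming=True)
print(type(stream).__name__)          # IterableDataset
print(hasattr(stream, "__len__"))     # False -> the Trainer cannot infer epochs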

The TrainingArguments used:

DATASET_STREAMING: bool = True
NUM_EPOCHS: int = 3
DATASET_TRAIN_NUM_SELECT: int = 2048
MAX_STEPS: int = NUM_EPOCHS * DATASET_TRAIN_NUM_SELECT if DATASET_STREAMING else -1

training_args = TrainingArguments(
    output_dir="bloom_finetuned",
    max_steps=MAX_STEPS,                        # <--- 2048 * 3
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    learning_rate=2e-5,
    weight_decay=0.01, 
    no_cuda=False,
)

When not using streaming (DATASET_STREAMING=False in the code below), Num Epochs is displayed as expected.

***** Running training *****
  Num examples = 2,048
  Num Epochs = 3
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 6,144
  Number of trainable parameters = 559,214,592

Who can help?

trainer: @sgugger

Reproduction

Run the code below

! pip install torch transformers datasets evaluate scikit-learn rouge rouge-score promptsource --quiet
import multiprocessing
import re
from typing import (
    List,
    Dict,
    Callable,
)

import evaluate
import numpy as np
from datasets import (
    load_dataset,
    get_dataset_split_names
)
from promptsource.templates import (
    DatasetTemplates,
    Template
)
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    DataCollatorForSeq2Seq,
    BloomForCausalLM,
    TrainingArguments,
    Trainer
)

## Huggingface Datasets
DATASET_NAME: str = "xsum"
DATASET_STREAMING: bool = True              # Whether to use dataset streaming
DATASET_TRAIN_NUM_SELECT: int = 2048        # Number of rows to use for training
DATASET_VALIDATE_NUM_SELECT: int = 128      # Number of rows to use for validation

# Huggingface Tokenizer (BLOOM default token length is 2048)
MAX_TOKEN_LENGTH: int = 512         # Max token length to avoid out of memory
PER_DEVICE_BATCH_SIZE: int = 1       # GPU batch size

# Huggingface Model
MODEL = "bigscience/bloom-560m"

# Training
NUM_EPOCHS: int = 3
MAX_STEPS: int = NUM_EPOCHS * DATASET_TRAIN_NUM_SELECT if DATASET_STREAMING else -1

train = load_dataset("xsum", split="train", streaming=DATASET_STREAMING)

prompt_templates = DatasetTemplates(dataset_name=DATASET_NAME)
template: Template = prompt_templates['summarize_DOC']

# Preprocess
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)

def get_convert_to_request_response(template: Template) -> Callable:
    def _convert_to_prompt_response(example: Dict[str, str]) -> Dict[str, str]:
        """Generate prompt, response as a dictionary:
        {
            "prompt": "Summarize: ...",
            "response": "..."
        }

        NOTE: DO NOT use with the dataset map function with batched=True. Use batched=False.

        Args:
            example: single {document, summary} pair to be able to apply template
        Returns: a dictionary of prompt and response
        """
        # assert isinstance(example, dict), f"expected dict but {type(example)}.\n{example}"
        assert isinstance(example['document'], str), f"expected str but {type(example['document'])}."
        prompt, response = template.apply(example=example, truncate=False)
        return {
            "prompt": re.sub(r'[\s\'\"]+', ' ', prompt),
            "response": re.sub(r'[\s\'\"]+', ' ', response)
        }

    return _convert_to_prompt_response

convert_to_request_response: Callable = get_convert_to_request_response(template=template)

def tokenize_prompt_response(examples):
    """Generate the model inputs in the dictionary with format:
    {
        "input_ids": List[int], 
        "attention_mask": List[int]",
        "labels": List[int]
    }

    Note: Hugging Face dataset map(batched=True, batch_size=n) merges the values of
    n dictionaries into a list of values per key. If you have n instances of {"key": "v"},
    then you will get {"key": ["v", "v", "v", ...]}.

    Args:
        examples:   a dictionary of format {
            "prompt": [prompt+],
            "response": [respnse+]
        } where + means more than one instance because of Dataset.map(batched=True)
    """    
    # Prompts are the model inputs, so tokenize them as text (not text_target).
    inputs: Dict[str, List[int]] = tokenizer(
        text=examples["prompt"],
        max_length=MAX_TOKEN_LENGTH,
        truncation=True
    )

    labels: Dict[str, List[int]] = tokenizer(
        text_target=examples["response"], 
        max_length=MAX_TOKEN_LENGTH, 
        truncation=True,
        padding='max_length',
    )
    inputs["labels"] = labels["input_ids"]

    return inputs

remove_column_names: List[str] = list(train.features.keys())
tokenized_train = train.map(
    function=convert_to_request_response, 
    batched=False,
    batch_size=2048,
    drop_last_batch=False,
    remove_columns=remove_column_names,
).map(
    function=tokenize_prompt_response, 
    batched=True,
    batch_size=32,
    drop_last_batch=True,
    remove_columns=['prompt', 'response']
).shuffle(
    seed=42
).with_format(
    "torch"
)

if DATASET_STREAMING:
    tokenized_train = tokenized_train.take(DATASET_TRAIN_NUM_SELECT)
else:
    tokenized_train = tokenized_train.select(
        indices=range(DATASET_TRAIN_NUM_SELECT)
    )

del train

tokenized_validation = load_dataset(
    path="xsum", 
    split="validation", 
    streaming=DATASET_STREAMING
).map(
    function=convert_to_request_response, 
    batched=False,
    batch_size=2048,
    drop_last_batch=False,
    remove_columns=remove_column_names,
).map(
    function=tokenize_prompt_response, 
    batched=True,
    batch_size=32,
    drop_last_batch=True,
    remove_columns=['prompt', 'response']
).with_format(
    "torch"
)

if DATASET_STREAMING:
    tokenized_validation = tokenized_validation.take(DATASET_VALIDATE_NUM_SELECT)
else:
    tokenized_validation = tokenized_validation.select(
        indices=range(DATASET_VALIDATE_NUM_SELECT)
    )

# Training
model = BloomForCausalLM.from_pretrained(MODEL)
model.cuda()

def predict(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors='pt')
    print(inputs["input_ids"].shape)

    response_tokens = model.generate(
        inputs["input_ids"].cuda(), 
        max_new_tokens=1,
        do_sample=False, 
        top_k=50, 
        top_p=0.9
    )[0]
    response = tokenizer.decode(response_tokens, skip_special_tokens=True)
    return response

# DataCollatorWithPadding does not pad 'labels' which causes an error at train()
# https://stackoverflow.com/a/74228547/4281353
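# (Note: a possible alternative is DataCollatorForSeq2Seq, already imported above,
#  which also pads the 'labels' field via label_pad_token_id; padding everything to
#  max_length here instead trades memory for simplicity.)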
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer, 
    padding='max_length',
    pad_to_multiple_of=8,
    max_length=MAX_TOKEN_LENGTH,
    return_tensors='pt'
)

# Evaluation
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

# Trainer API
training_args = TrainingArguments(
    output_dir="bloom_finetuned",
    max_steps=MAX_STEPS,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    learning_rate=2e-5,
    weight_decay=0.01, 
    fp16=True,
    no_cuda=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=3,
    log_level="debug",
    disable_tqdm=False,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_validation,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

Expected behavior

Get the intended 3 epochs, or an explanation of the Num Epochs value (9223372036854775807).

When not using streaming (DATASET_STREAMING=False), Num Epochs is displayed as expected.

***** Running training *****
  Num examples = 2,048
  Num Epochs = 3
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 6,144
  Number of trainable parameters = 559,214,592

sgugger commented 1 year ago

That's because the dataset you are using does not have a length, so the Trainer sets the number of epochs to a very high number to make sure it does the number of steps you are asking for.
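
In simplified terms (a sketch of the logic, not the actual Trainer source):

import sys

def resolve_num_epochs(dataset_len, max_steps, num_train_epochs):
    # Sized dataset: the epoch count can be derived from max_steps
    # (assuming batch size 1 and no gradient accumulation) or taken as given.
    if dataset_len is not None:
        if max_steps > 0:
            return max(max_steps // dataset_len, 1)
        return num_train_epochs
    # Unsized (streaming) dataset: only max_steps is meaningful, so the epoch
    # counter is set to a huge sentinel and training stops after max_steps.
    return sys.maxsize

print(resolve_num_epochs(2048, 6144, 3))   # 3
print(resolve_num_epochs(None, 6144, 3))   # 9223372036854775807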

oonisim commented 1 year ago

@sgugger , thanks for the explanation.

May I suggest updating the documentation to describe the Trainer behavior and the requirements for streaming datasets, e.g. that max_steps must be used and what value to set? Otherwise users may keep raising questions about max_steps and epochs (there have been at least 3 questions in the forum).

Otherwise, I am afraid you may need to spend your time answering each time the question comes up.

Currently, the Datasets Stream and Trainer documents have no such information as far as I can tell (please correct me if I missed it).
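
For example, a short note along these lines (my own sketch; the variable names are illustrative) would answer the question:

# With a streaming (IterableDataset) training set, the Trainer cannot know the
# dataset length, so num_train_epochs alone cannot be honored and max_steps must
# be set. To get the equivalent of N epochs over the examples you stream:
desired_epochs = 3
num_streamed_examples = 2048              # e.g. dataset.take(2048)
per_device_train_batch_size = 1
gradient_accumulation_steps = 1
n_devices = 1

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * n_devices
max_steps = desired_epochs * (num_streamed_examples // effective_batch_size)   # 6144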

sgugger commented 1 year ago

We welcome any PR making the documentation better :-)