huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Trainer class causes massive memory leak when using mps #33717

Open JamesBowerXanda opened 2 days ago

JamesBowerXanda commented 2 days ago

System Info

Who can help?

No response

Information

Tasks

Reproduction

from datasets import load_dataset
from transformers import Trainer, TrainingArguments, LlamaConfig, LlamaForCausalLM, LlamaTokenizer, DataCollatorForLanguageModeling

dataset = load_dataset("roneneldan/TinyStories")
config = {
    "vocab_size": 32000,
    "hidden_size": 128,
    "num_hidden_layers": 4,
    "num_attention_heads": 8,
    "intermediate_size": 256,
    "hidden_act": "silu",
    "max_position_embeddings": 512,
    "initializer_range": 0.02,
    "rms_norm_eps": 1e-6,
}

model_config = LlamaConfig(**config)
model = LlamaForCausalLM(model_config)
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b")
tokenizer.pad_token_id = tokenizer.eos_token_id

dataset["train"] = dataset["train"].select(range(100000))
dataset["validation"] = dataset["validation"].select(range(1000))
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=4096)
tokenized_datasets = dataset.map(tokenize_function, batched=True)

data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

args = TrainingArguments(
    output_dir="llama-tiny-stories",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=16,
    evaluation_strategy="steps",
    eval_steps=3000,
    logging_steps=100,
    gradient_accumulation_steps=1,
    num_train_epochs=6,
    weight_decay=1e-5,
    learning_rate=1e-4,
    save_steps=0,
    fp16=False,
    push_to_hub=False,
    use_mps_device=True,
    lr_scheduler_type="constant"
)  

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)

trainer.train()
  1. Open activity monitor
  2. Run the snippet above (tested both in notebook and as a script).

The memory of the Python process keeps growing continuously, and training is far slower than when you set use_mps_device=False and add use_cpu=True.
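
For reference, the CPU baseline referred to here presumably just swaps the device flags in TrainingArguments; a minimal sketch (all other arguments as in the script above):

from transformers import TrainingArguments

# Hypothetical CPU baseline: identical setup, but force CPU instead of MPS.
args_cpu = TrainingArguments(
    output_dir="llama-tiny-stories-cpu",
    per_device_train_batch_size=4,
    use_mps_device=False,  # do not place the model on the "mps" device
    use_cpu=True,          # run the whole training on CPU
)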

Expected behavior

  1. I would expect the script to run with roughly constant memory consumption, similar to the CPU run.
  2. I would expect the training to be faster than the CPU run.
RUFFY-369 commented 2 days ago

Hi @JamesBowerXanda, can you try adding skip_memory_metrics=False in TrainingArguments?
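
For reference, that would look something like this (only the memory-metrics argument changes; everything else stays as in the original script):

from transformers import TrainingArguments

# Same setup as above, but with memory metrics enabled so the Trainer
# reports memory usage for the init/train/eval stages in its logs.
args = TrainingArguments(
    output_dir="llama-tiny-stories",
    per_device_train_batch_size=4,
    use_mps_device=True,
    skip_memory_metrics=False,  # default is True, which skips the memory reports
)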

JamesBowerXanda commented 2 days ago

@RUFFY-369 I have tried that and still get the same issue: the memory keeps creeping up and training progresses more slowly than on CPU alone.

RUFFY-369 commented 2 days ago

@JamesBowerXanda Have you tried training a simple PyTorch model to test the memory build-up on the MPS device? That way we can confirm whether it is a PyTorch MPS issue or a transformers issue.
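
For example, a minimal check along these lines (a toy model on random data, chosen only to exercise the MPS backend; watch the process memory in Activity Monitor while it runs):

import torch
import torch.nn as nn

# Fall back to CPU if MPS is not available on this machine.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Small throwaway model and random data, just to exercise the backend.
model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 10)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for step in range(2000):
    x = torch.randn(64, 512, device=device)
    y = torch.randint(0, 10, (64,), device=device)

    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

    if step % 200 == 0:
        print(f"step {step}, loss {loss.item():.4f}")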

JamesBowerXanda commented 2 days ago

I have not, I will give that a go now.

RUFFY-369 commented 2 days ago

Sure, and then you can update us here with the results. Cheers!

JamesBowerXanda commented 2 days ago

I ran a quick Fashion-MNIST training script in pure PyTorch and everything behaved as expected: no memory leak, and MPS ran faster than CPU.

I will try to implement the training loop for the Llama model in pure PyTorch to confirm whether the problem lies with the Trainer object or with how MPS handles the specific layers in the transformer architecture.

JamesBowerXanda commented 2 days ago

@RUFFY-369 OK, I have implemented a cruder version of the training loop using the LlamaForCausalLM model, and it again works as expected: memory stays pretty much constant and it runs much faster than the CPU-only implementation.

There is a slight increase in memory while training is running (so far it has gone from 2.88 GB to 3 GB), but nothing like before, when it was jumping up by gigabytes very quickly.

Here is the code I used for reference:

from datasets import load_dataset
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torch

device = "mps"

dataset = load_dataset("roneneldan/TinyStories")
config = {
    "vocab_size": 32000,
    "hidden_size": 128,
    "num_hidden_layers": 4,
    "num_attention_heads": 8,
    "intermediate_size": 256,
    "hidden_act": "silu",
    "max_position_embeddings": 512,
    "initializer_range": 0.02,
    "rms_norm_eps": 1e-6,
}

model_config = LlamaConfig(**config)
model = LlamaForCausalLM(model_config)
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b")
tokenizer.pad_token_id = tokenizer.eos_token_id

dataset["train"] = dataset["train"].select(range(100000))
dataset["validation"] = dataset["validation"].select(range(1000))
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, max_length=4096)
tokenized_datasets = dataset.map(tokenize_function, batched=True)
optimizer = AdamW(model.parameters(), lr=5e-5)

def collate_fn(batch):
    input_ids = [sample["input_ids"] for sample in batch]
    attention_mask = [sample["attention_mask"] for sample in batch]

    for i in range(len(input_ids)):
        input_id = input_ids[i]
        if len(input_id) < 512:
            input_ids[i] = input_id + [tokenizer.pad_token_id] * (512 - len(input_id))
            attention_mask[i] = [1] * len(input_id) + [0] * (512 - len(input_id))
        else:
            input_ids[i] = input_id[:512]
            attention_mask[i] = [1] * 512

    labels = [
        input_id[1:] + [tokenizer.pad_token_id] for input_id in input_ids
    ]
    return {
        "input_ids": torch.tensor(input_ids).to(device),
        "attention_mask": torch.tensor(attention_mask).to(device),
        "labels": torch.tensor(labels).to(device)
    }

train_dataloader = DataLoader(tokenized_datasets["train"], batch_size=4, shuffle=True, collate_fn=collate_fn)

model.to(device)

for i, batch in enumerate(train_dataloader):
    optimizer.zero_grad()

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]

    logits = model(input_ids, attention_mask=attention_mask).logits

    loss = torch.nn.CrossEntropyLoss()(logits.view(-1, logits.shape[-1]), labels.view(-1))

    loss.backward()

    optimizer.step()

    if i % 100 == 0:
        print(f"Loss: {loss.item()}")
RUFFY-369 commented 2 days ago

@JamesBowerXanda, okay, then the issue arises in the Trainer class train call. Let's do this: in the training args, can you try setting torch_empty_cache_steps to some value? It should be an integer greater than zero, i.e. the number of steps after which you want to free the piling-up memory cache.
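
For example (the interval of 100 here is just an illustrative value; only the relevant arguments are shown, everything else as in your script):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="llama-tiny-stories",
    per_device_train_batch_size=4,
    use_mps_device=True,
    torch_empty_cache_steps=100,  # call the backend's empty_cache() every 100 steps
)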

RUFFY-369 commented 2 days ago

@JamesBowerXanda Have you tried this?

RUFFY-369 commented 2 days ago

This arg has a default value of None, so maybe that's why the memory piles up.
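
Outside the Trainer, the same idea can be checked manually; a rough sketch of what the argument automates (assumes a PyTorch build that exposes torch.mps):

import torch

EMPTY_CACHE_STEPS = 100  # hypothetical interval, mirroring torch_empty_cache_steps

for step in range(1000):
    # ... forward pass, loss.backward(), optimizer.step() would go here ...
    if step % EMPTY_CACHE_STEPS == 0 and torch.backends.mps.is_available():
        torch.mps.empty_cache()  # release cached MPS blocks back to the system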

JamesBowerXanda commented 1 day ago

@RUFFY-369 I have just added that in, and it does seem to slow the memory creep, but the creep is still there and training is still incredibly slow. A couple of tests I did had the following timings:

  1. Using the Trainer with CPU, it took 1 minute 55 seconds to do 30 steps.
  2. Using the Trainer with MPS, it took 2 minutes 36 seconds to do 30 steps.
  3. With my own training loop (which I appreciate isn't a completely fair comparison, because it didn't implement all the logging and other parts), it was doing 30 steps in a second. This is with the same model, the same batch size, and all inputs padded to the model's max token length.
JamesBowerXanda commented 1 day ago

I should also add that I changed the preprocessing of the dataset in the original script to tokenize to a max length of 512, since that is the max position embeddings set by the model. I made this change before running the experiments quoted above.
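
The exact change wasn't shown above, but presumably the preprocessing now looks something like this:

def tokenize_function(examples):
    # truncate to the model's max_position_embeddings (512) instead of 4096
    return tokenizer(examples["text"], truncation=True, max_length=512)

tokenized_datasets = dataset.map(tokenize_function, batched=True)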

JamesBowerXanda commented 1 day ago

The change in speed seems very extreme, so I am concerned I have done something very stupid somewhere, but for the life of me I can't see where.

RUFFY-369 commented 14 hours ago

(Quoting the reproduction script from the original report.)

Hi @JamesBowerXanda , I tested your original script on cuda and compared the speed with CPU. The result was as follows:

  • On CUDA, the script did 6000 steps in 4 minutes 21 seconds, at 23.6 it/s.

[Screenshot of the CUDA training run, 2024-09-28]

So you can clearly see that the Trainer works well, the modeling script works well, and your training script works well, because CUDA is going brrrrrr :).

The test script you made for comparing simple PyTorch model training possibly worked as desired on MPS because it had none of the intricacies of the Trainer; the only issue here, I think, is the MPS device running the PyTorch model training. And that happens quite often, as MPS support is still fairly new in PyTorch.

JamesBowerXanda commented 13 hours ago

@RUFFY-369 OK, so there is some interaction between the Trainer code and PyTorch when running on MPS? Is this something that might be picked up, since use_mps_device is an argument of the TrainingArguments object?

RUFFY-369 commented 13 hours ago

Is this something that might be picked up, since use_mps_device is an argument of the TrainingArguments object?

@JamesBowerXanda Well, use_mps_device itself is deprecated.

JamesBowerXanda commented 13 hours ago

@RUFFY-369 Sorry, yes, but the Trainer does default to using the MPS device.
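
For context, on Apple Silicon you can see which device the Trainer will pick up by default; a quick check (assuming a recent transformers/PyTorch):

import torch
from transformers import TrainingArguments

print(torch.backends.mps.is_available())  # True on Apple Silicon with a recent PyTorch

args = TrainingArguments(output_dir="tmp")
print(args.device)  # expected to resolve to device(type="mps") on such a machine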

RUFFY-369 commented 12 hours ago

@RUFFY-369 OK, so there is some interaction between the Trainer code and PyTorch when running on MPS? Is this something that might be picked up, since use_mps_device is an argument of the TrainingArguments object?

@JamesBowerXanda Not particularly, I guess, also because there isn't anything MPS-specific in the Trainer class. It is basically just the MPS backend, where the PyTorch ops are implemented as custom Metal shaders and then placed on the `mps` device. It has to be something to do with the cache, I assume.
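
One way to probe the cache hypothesis on the MPS side would be to log the backend's own counters between steps (a sketch; these torch.mps utilities need a reasonably recent PyTorch):

import torch

# Comparing these two numbers over training steps should show whether memory
# is held by live tensors or sits in the caching allocator / driver.
print("allocated by tensors:", torch.mps.current_allocated_memory())
print("allocated by driver: ", torch.mps.driver_allocated_memory())

torch.mps.empty_cache()  # drop cached blocks, then re-check the numbers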

Anyhow, for this, we may need a second opinion as I don't have access to any MPS system to run the code and test.

cc @muellerzr @SunMarc @ydshieh