JamesBowerXanda opened this issue 2 days ago
Hi @JamesBowerXanda, can you try adding `skip_memory_metrics=False` to `TrainingArguments`?
@RUFFY-369 I have tried that and still get the same issue: memory keeps creeping up and training progresses more slowly than on CPU only.
@JamesBowerXanda Have you tried training a simple PyTorch model to test the memory build-up on the MPS device? That way we can confirm whether it is a PyTorch MPS issue or a Transformers issue.
I have not; I will give that a go now.
Sure, and then post an update here. Cheers!
I ran a quick Fashion-MNIST training script with pure PyTorch, and everything behaved as expected: memory did not leak and MPS ran faster than CPU.
I will try to implement the training loop for the Llama model in pure PyTorch to confirm whether the problem lies with the `Trainer` object or with how MPS handles the specific layers in the transformer architecture.
@RUFFY-369 OK, I have implemented a cruder version of the training loop using the `LlamaForCausalLM` model, and it again seems to work as expected: memory stays roughly constant and it runs much faster than the CPU-only implementation.
There is a slight increase in memory while training runs (so far it has gone from 2.88 GB to 3 GB), but that is nothing compared to before, when it was quickly climbing by gigabytes.
Here is the code I used for reference:
```python
from datasets import load_dataset
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
from torch.optim import AdamW
from torch.utils.data import DataLoader
import torch

device = "mps"
dataset = load_dataset("roneneldan/TinyStories")

config = {
    "vocab_size": 32000,
    "hidden_size": 128,
    "num_hidden_layers": 4,
    "num_attention_heads": 8,
    "intermediate_size": 256,
    "hidden_act": "silu",
    "max_position_embeddings": 512,
    "initializer_range": 0.02,
    "rms_norm_eps": 1e-6,
}
model_config = LlamaConfig(**config)
model = LlamaForCausalLM(model_config)

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b")
tokenizer.pad_token_id = tokenizer.eos_token_id

dataset["train"] = dataset["train"].select(range(100000))
dataset["validation"] = dataset["validation"].select(range(1000))

def tokenize_function(examples):
    # max_length only takes effect together with truncation=True;
    # collate_fn below pads/truncates to exactly 512 tokens anyway.
    return tokenizer(examples["text"], truncation=True, max_length=512)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Move the model first, then build the optimizer over its parameters.
model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)

def collate_fn(batch):
    input_ids = [sample["input_ids"] for sample in batch]
    attention_mask = [sample["attention_mask"] for sample in batch]
    # Pad or truncate every sequence to a fixed length of 512.
    for i in range(len(input_ids)):
        input_id = input_ids[i]
        if len(input_id) < 512:
            input_ids[i] = input_id + [tokenizer.pad_token_id] * (512 - len(input_id))
            attention_mask[i] = [1] * len(input_id) + [0] * (512 - len(input_id))
        else:
            input_ids[i] = input_id[:512]
            attention_mask[i] = [1] * 512
    # Next-token targets: shift left by one and pad the final position.
    labels = [input_id[1:] + [tokenizer.pad_token_id] for input_id in input_ids]
    return {
        "input_ids": torch.tensor(input_ids).to(device),
        "attention_mask": torch.tensor(attention_mask).to(device),
        "labels": torch.tensor(labels).to(device),
    }

train_dataloader = DataLoader(tokenized_datasets["train"], batch_size=4, shuffle=True, collate_fn=collate_fn)

model.train()
for i, batch in enumerate(train_dataloader):
    optimizer.zero_grad()
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    # Labels are already shifted, so the loss is computed directly on the logits.
    logits = model(input_ids, attention_mask=attention_mask).logits
    loss = torch.nn.CrossEntropyLoss()(logits.view(-1, logits.shape[-1]), labels.view(-1))
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        print(f"Loss: {loss.item()}")
```
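When comparing this loop with the `Trainer` run, note that the two build batches differently. The `Trainer` script uses `DataCollatorForLanguageModeling(mlm=False)`, which, as far as I know, pads each batch only to its longest sequence, copies `input_ids` into `labels`, replaces pad positions with -100 so the loss skips them, and leaves the one-token shift to the model. The `collate_fn` above pads every batch to exactly 512 tokens, shifts the labels itself, and keeps pad tokens in the loss. The plain-Python sketch below contrasts the two layouts on token-id lists; both function names are hypothetical, and right-padding is assumed for simplicity (the real tokenizer's padding side may differ).

```python
from typing import Dict, List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def collator_style_batch(seqs: List[List[int]], pad_id: int) -> Dict[str, List[List[int]]]:
    """Rough approximation of DataCollatorForLanguageModeling(mlm=False)."""
    width = max(len(s) for s in seqs)  # dynamic: depends on this batch only
    input_ids = [s + [pad_id] * (width - len(s)) for s in seqs]
    attention_mask = [[1] * len(s) + [0] * (width - len(s)) for s in seqs]
    # Unshifted copy of the inputs; padded positions are excluded from the loss.
    labels = [[IGNORE_INDEX if tok == pad_id else tok for tok in row] for row in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


def manual_loop_batch(seqs: List[List[int]], pad_id: int, width: int = 512) -> Dict[str, List[List[int]]]:
    """Same layout as the collate_fn above, without the tensor conversion."""
    input_ids = [(s + [pad_id] * (width - len(s)))[:width] for s in seqs]
    attention_mask = [([1] * len(s) + [0] * (width - len(s)))[:width] for s in seqs]
    # Pre-shifted by one token; pad tokens stay in the loss.
    labels = [row[1:] + [pad_id] for row in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
```

With the collator-style layout the batch width changes from batch to batch, while the manual loop always produces 512-token-wide batches.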
@JamesBowerXanda OK, then the issue arises in the `Trainer` class's `train` call. Let's try this:
In the training args, can you try setting `torch_empty_cache_steps` to a certain value? The value should be an integer greater than zero, i.e., the number of steps after which you want to free the accumulated memory cache.
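The option's description suggests a simple schedule: every N steps, release the backend's cached-but-unused memory (on MPS the corresponding PyTorch call is `torch.mps.empty_cache()`). A plain-Python sketch of that schedule, assuming this reading of the option and not reproducing the actual `Trainer` code, could look like this; `run_steps` and its callables are hypothetical.

```python
from typing import Callable, List, Optional


def run_steps(
    num_steps: int,
    train_step: Callable[[int], None],
    empty_cache: Callable[[], None],
    empty_cache_steps: Optional[int] = None,
) -> List[int]:
    """Run num_steps steps, releasing the allocator cache every empty_cache_steps steps.

    Returns the steps after which the cache was cleared. In a real MPS loop,
    empty_cache would be torch.mps.empty_cache (assumption for illustration).
    """
    if empty_cache_steps is not None and empty_cache_steps <= 0:
        raise ValueError("empty_cache_steps must be a positive integer or None")
    cleared = []
    for step in range(1, num_steps + 1):
        train_step(step)
        # None (the default) means the cache is never explicitly released.
        if empty_cache_steps is not None and step % empty_cache_steps == 0:
            empty_cache()
            cleared.append(step)
    return cleared
```

For example, `run_steps(10, lambda s: None, lambda: None, empty_cache_steps=4)` returns `[4, 8]`. Emptying the cache very often can itself cost time, because freed blocks have to be requested from the driver again on later steps.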
@JamesBowerXanda Have you tried this?
This arg defaults to `None`, so maybe that's why the memory piles up.
@RUFFY-369 I have just added that, and it does seem to slow the memory creep, but the creep is still there and training is incredibly slow. A couple of tests I did had the following timings:
I should also add that I changed the preprocessing of the dataset in the original script to tokenize to a max length of 512, since that is the `max_position_embeddings` value set in the model config. I made this change before running the experiments I just quoted.
The change in speed seems very extreme, so I am concerned I have done something very stupid somewhere, but for the life of me I can't see where.
Reproduction
```python
from datasets import load_dataset
from transformers import Trainer, TrainingArguments, LlamaConfig, LlamaForCausalLM, LlamaTokenizer, DataCollatorForLanguageModeling

dataset = load_dataset("roneneldan/TinyStories")

config = {
    "vocab_size": 32000,
    "hidden_size": 128,
    "num_hidden_layers": 4,
    "num_attention_heads": 8,
    "intermediate_size": 256,
    "hidden_act": "silu",
    "max_position_embeddings": 512,
    "initializer_range": 0.02,
    "rms_norm_eps": 1e-6,
}
model_config = LlamaConfig(**config)
model = LlamaForCausalLM(model_config)

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b")
tokenizer.pad_token_id = tokenizer.eos_token_id

dataset["train"] = dataset["train"].select(range(100000))
dataset["validation"] = dataset["validation"].select(range(1000))

def tokenize_function(examples):
    return tokenizer(examples["text"], max_length=4096)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

args = TrainingArguments(
    output_dir="llama-tiny-stories",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=16,
    evaluation_strategy="steps",
    eval_steps=3000,
    logging_steps=100,
    gradient_accumulation_steps=1,
    num_train_epochs=6,
    weight_decay=1e-5,
    learning_rate=1e-4,
    save_steps=0,
    fp16=False,
    push_to_hub=False,
    use_mps_device=True,
    lr_scheduler_type="constant",
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)
trainer.train()
```
Hi @JamesBowerXanda, I tested your original script on CUDA and compared the speed with CPU. The result was as follows:
- On CUDA, the script did 6000 steps in 4 min 21 s at 23.6 it/s.
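For comparing runs across devices, the average rate is just steps over wall-clock time: 6000 steps in 4 min 21 s (261 s) is about 23.0 it/s, slightly below the 23.6 it/s quoted, which is expected if the progress bar shows a smoothed recent rate rather than the whole-run average. The sketch below times a fixed number of steps; `average_rate` and `measure_throughput` are hypothetical helpers. One assumption to flag: MPS, like CUDA, runs kernels asynchronously, so on the Mac the `synchronize` argument should be `torch.mps.synchronize`, otherwise the clock may stop before queued work has finished.

```python
import time
from typing import Callable


def average_rate(steps: int, seconds: float) -> float:
    """Average iterations per second over a whole run."""
    return steps / seconds


def measure_throughput(
    step_fn: Callable[[], None],
    num_steps: int,
    synchronize: Callable[[], None] = lambda: None,
    warmup_steps: int = 5,
) -> float:
    """Time num_steps calls of step_fn and return iterations per second.

    On an accelerator, pass the backend's synchronize function
    (assumed here: torch.mps.synchronize on MPS) so queued work is counted.
    """
    for _ in range(warmup_steps):  # exclude one-off startup costs
        step_fn()
    synchronize()
    start = time.perf_counter()
    for _ in range(num_steps):
        step_fn()
    synchronize()  # wait for all queued kernels before stopping the clock
    elapsed = max(time.perf_counter() - start, 1e-9)  # guard against a zero interval
    return num_steps / elapsed


print(f"{average_rate(6000, 4 * 60 + 21):.1f} it/s")  # 6000 steps in 4 min 21 s
```

Running the same `measure_throughput` call on `cpu` and `mps` with an identical `step_fn` keeps the comparison apples to apples.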
So you can clearly see that the `Trainer` works well, the modeling script works well, and your training script works well, because CUDA is goinggg brrrrrrrrr :)
The test script you wrote for plain PyTorch training probably worked as desired on MPS because it involved none of the `Trainer`'s intricacies. I think the only issue here is the MPS device running the PyTorch model training, which happens quite often since MPS support in PyTorch is still fairly new.
@RUFFY-369 OK, so there is some interaction between the `Trainer` script and PyTorch when running on MPS? Is this something that might be picked up, since `use_mps_device` is an argument of the `TrainingArguments` object?
@JamesBowerXanda Well, `use_mps_device` itself is deprecated.
@RUFFY-369 Sorry, yes, but the `Trainer` does default to using the MPS device.
@JamesBowerXanda Well, not particularly, I guess, and also because there isn't anything MPS-specific in the `Trainer` class. It is basically just the MPS backend, where PyTorch ops are implemented as custom Metal shaders and tensors are placed on the `mps` device. I assume it has something to do with the cache.
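One way to test the cache idea on the Mac is to log two numbers every few hundred steps: memory held by live tensors and the total memory the Metal driver has allocated for the process, which also covers cached blocks. Recent PyTorch releases expose these as `torch.mps.current_allocated_memory()` and `torch.mps.driver_allocated_memory()`. The plain-Python sketch below takes the two readers as callables so it runs without a GPU; `MemoryLog` is a hypothetical helper, not part of either library.

```python
from typing import Callable, List, Tuple

GIB = 1024 ** 3


class MemoryLog:
    """Hypothetical helper: record tensor vs. driver memory at given steps."""

    def __init__(self, read_tensor_bytes: Callable[[], int], read_driver_bytes: Callable[[], int]):
        self.read_tensor_bytes = read_tensor_bytes  # e.g. torch.mps.current_allocated_memory
        self.read_driver_bytes = read_driver_bytes  # e.g. torch.mps.driver_allocated_memory
        self.rows: List[Tuple[int, int, int]] = []

    def record(self, step: int) -> None:
        self.rows.append((step, self.read_tensor_bytes(), self.read_driver_bytes()))

    def growth_gib_per_1k_steps(self) -> Tuple[float, float]:
        """(tensor growth, driver growth) between first and last record, in GiB per 1000 steps."""
        if len(self.rows) < 2:
            raise ValueError("need at least two records")
        (s0, t0, d0), (s1, t1, d1) = self.rows[0], self.rows[-1]
        if s1 <= s0:
            raise ValueError("records must cover an increasing step range")
        scale = 1000 / (s1 - s0) / GIB
        return (t1 - t0) * scale, (d1 - d0) * scale
```

On an Apple-silicon machine this would be built as `MemoryLog(torch.mps.current_allocated_memory, torch.mps.driver_allocated_memory)`, with `record(state.global_step)` called periodically, for instance from a `TrainerCallback`'s `on_step_end` hook. A driver figure that keeps climbing while the tensor figure stays flat would point at cached or driver-side allocations rather than tensors being kept alive.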
Anyhow, we may need a second opinion here, as I don't have access to an MPS machine to run and test the code.
cc @muellerzr @SunMarc @ydshieh
System Info
- `transformers` version: 4.44.2

Who can help?
No response

Reproduction
The memory of the Python process keeps growing continuously, and training is super slow compared to running with `use_mps_device=False` and `use_cpu=True`.

Expected behavior