Open katherine-atwell opened 3 months ago
cc @muellerzr @SunMarc
Hey @katherine-atwell, is the generator device set to mps?
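In case it helps, here is a minimal sketch of one way to check which device a generator ends up on. The helpers `pick_device` and `generator_matches` are hypothetical names used only for illustration; they are not part of transformers or PyTorch. The torch part only runs when torch is installed, and it assumes a PyTorch build recent enough to ship the MPS backend.

```python
import importlib.util


def pick_device(mps_available: bool) -> str:
    """Choose the device string for a generator: 'mps' when available, else 'cpu'."""
    return "mps" if mps_available else "cpu"


def generator_matches(generator_device: str, expected: str) -> bool:
    """Compare device types only, so that 'mps:0' counts as a match for 'mps'."""
    return generator_device.split(":")[0] == expected.split(":")[0]


# Only touch torch when it is actually installed.
if importlib.util.find_spec("torch") is not None:
    import torch

    # Older PyTorch builds have no torch.backends.mps, hence the hasattr guard.
    mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    device = pick_device(mps_ok)

    gen = torch.Generator(device=device)
    print(f"generator device: {gen.device}")
    print(f"matches {device}: {generator_matches(str(gen.device), device)}")
```

If the printed device is cpu on a machine where MPS is available, the generator was most likely created before the device was chosen, or without a `device` argument.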
Hey @katherine-atwell, can you try the following code to see if it works:
import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, AutoTokenizer

# Load the IMDb dataset
imdb = load_dataset("imdb")


class TweetDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self._data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Tokenize every row up front, padded/truncated to max_length
        self.encodings = tokenizer(
            [row["text"] for row in data],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        self.labels = torch.tensor([row["label"] for row in data])

    def __getitem__(self, idx):
        # Return the dict of tokenized tensors and add the label
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item

    def __len__(self):
        return len(self._data)


base_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(base_model)
tokenizer.pad_token = tokenizer.eos_token

# Create training and validation datasets
train_data = TweetDataset(imdb["train"], tokenizer, 1024)
val_data = TweetDataset(imdb["test"], tokenizer, 1024)

num_labels = 2
model = AutoModelForSequenceClassification.from_pretrained(base_model, num_labels=num_labels)
model.config.pad_token_id = model.config.eos_token_id
model.to("mps")

# Define training arguments
training_args = TrainingArguments(
    remove_unused_columns=False,
    output_dir="./"
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
)

# Start training
trainer.train()
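One practical note on the snippet above: it tokenizes each full IMDb split up front and pads every review to 1024 tokens, so the encodings alone take a fair amount of memory. The rough estimate below is only a sketch. It assumes 25,000 rows per split, two tensors per row (`input_ids` and `attention_mask`), and 8 bytes per value (int64); `encoding_bytes` is a hypothetical helper, not a library function.

```python
def encoding_bytes(num_rows: int, max_length: int,
                   num_fields: int = 2, bytes_per_value: int = 8) -> int:
    """Estimate bytes held by padded encodings: rows x length x fields x value size."""
    return num_rows * max_length * num_fields * bytes_per_value


# Full split as used in the snippet above (assumed 25,000 rows).
full_split = encoding_bytes(25_000, 1024)
# A small subset, enough for a quick smoke test of the training loop.
small_subset = encoding_bytes(64, 1024)

print(f"full split:   {full_split / 1e6:.1f} MB")
print(f"64-row slice: {small_subset / 1e6:.2f} MB")
```

If the goal is only to check whether `trainer.train()` starts on MPS, passing a slice such as `imdb["train"].select(range(64))` to `TweetDataset` should be enough and keeps the run short.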
System Info
transformers version: 4.37.2
Who can help?
No response
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)
Reproduction
While trying to run Trainer.train() on a Mac device, I run into the following error:
This error is caused by the following code:
train_data and val_data are instances of the following custom dataset:
Expected behavior
Generators should be initialized on MPS rather than on the CPU, and training should run without throwing an error.