zhaochenyang20 opened this issue 1 year ago
I can't find an easy, clear way to disable MPS usage. 🤔
The following code failed:
from transformers import Seq2SeqTrainingArguments

hyperparameter_choices = {}  # placeholder: the issue does not show how this dict is built

training_args = Seq2SeqTrainingArguments(
    output_dir=hyperparameter_choices.get("output_dir", "./result"),
    logging_steps=hyperparameter_choices.get("logging_steps", 1),
    save_strategy=hyperparameter_choices.get("save_strategy", "no"),
    num_train_epochs=hyperparameter_choices.get("num_train_epochs", 10),
    per_device_train_batch_size=hyperparameter_choices.get(
        "per_device_train_batch_size", 100
    ),
    warmup_steps=hyperparameter_choices.get("warmup_steps", 0),
    weight_decay=hyperparameter_choices.get("weight_decay", 0.01),
    logging_dir=hyperparameter_choices.get("logging_dir", "./result"),
    learning_rate=hyperparameter_choices.get("learning_rate", 1e-4),
    predict_with_generate=True,
    use_mps_device=False,  # intended to keep training off MPS
)
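A minimal sketch of one possible workaround, assuming a transformers release whose `TrainingArguments` accepts `use_cpu` (some older releases used `no_cuda=True` for a similar purpose). Setting `use_mps_device=False` only means MPS is not explicitly requested; the device may still be picked up automatically when MPS is available, so forcing CPU is the more direct switch. `build_training_kwargs` is a hypothetical helper that mirrors the arguments above and returns them as a plain dict, so the device flag can be inspected before the arguments object is built.

```python
def build_training_kwargs(hyperparameter_choices: dict, force_cpu: bool = True) -> dict:
    """Hypothetical helper: collect Seq2SeqTrainingArguments kwargs in one dict."""
    kwargs = {
        "output_dir": hyperparameter_choices.get("output_dir", "./result"),
        "logging_steps": hyperparameter_choices.get("logging_steps", 1),
        "save_strategy": hyperparameter_choices.get("save_strategy", "no"),
        "num_train_epochs": hyperparameter_choices.get("num_train_epochs", 10),
        "per_device_train_batch_size": hyperparameter_choices.get(
            "per_device_train_batch_size", 100
        ),
        "warmup_steps": hyperparameter_choices.get("warmup_steps", 0),
        "weight_decay": hyperparameter_choices.get("weight_decay", 0.01),
        "logging_dir": hyperparameter_choices.get("logging_dir", "./result"),
        "learning_rate": hyperparameter_choices.get("learning_rate", 1e-4),
        "predict_with_generate": True,
    }
    if force_cpu:
        # Assumption: the installed transformers version supports `use_cpu`,
        # which asks the trainer to stay on CPU even if MPS or CUDA is available.
        # `use_mps_device` is deliberately left out rather than set to False.
        kwargs["use_cpu"] = True
    return kwargs
```

The dict would then be unpacked into the arguments class, e.g. `Seq2SeqTrainingArguments(**build_training_kwargs(hyperparameter_choices))`.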
To clarify, what is the error you get when you run that code?
Hmm, I'll post it later.
Our current trainer does not support MPS training.
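Since the trainer does not support MPS, a caller could choose the device explicitly and fall back to CPU on Apple Silicon. The sketch below is a hypothetical helper, not part of the project's trainer; it only encodes the priority CUDA, then MPS if explicitly allowed, then CPU. With PyTorch installed, the two availability flags could come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.

```python
def select_device(cuda_available: bool, mps_available: bool, allow_mps: bool = False) -> str:
    """Hypothetical helper: return a device string, skipping MPS unless allowed."""
    if cuda_available:
        return "cuda"
    # MPS is used only when the caller opts in; otherwise fall back to CPU.
    if mps_available and allow_mps:
        return "mps"
    return "cpu"


# Example: on an Apple Silicon machine without CUDA, MPS is skipped by default.
print(select_device(cuda_available=False, mps_available=True))
```

Keeping MPS opt-in means the default path never hands the trainer a device it cannot handle.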