neulab / prompt2model

prompt2model - Generate Deployable Models from Natural Language Instructions
Apache License 2.0

Support for MPS device #317

Open · zhaochenyang20 opened this issue 1 year ago

zhaochenyang20 commented 1 year ago

Our current trainer does not support training on MPS (PyTorch's Metal Performance Shaders backend for Apple silicon GPUs).
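
One way a trainer could stay off MPS is to choose the device itself. The following is a minimal sketch, assuming hypothetical helpers `pick_device` and `detect_device` (neither is part of prompt2model or transformers). They prefer CUDA and use MPS only when explicitly allowed. `torch.backends.mps.is_available()` is PyTorch's own MPS availability check.

```python
def pick_device(cuda_available: bool, mps_available: bool, allow_mps: bool = False) -> str:
    """Hypothetical helper: choose a device string, using MPS only when allowed."""
    if cuda_available:
        return "cuda"
    if mps_available and allow_mps:
        return "mps"
    # Neither CUDA nor an allowed MPS backend: train on the CPU.
    return "cpu"


def detect_device(allow_mps: bool = False) -> str:
    """Probe PyTorch for CUDA/MPS; fall back to "cpu" if torch is not installed."""
    try:
        import torch
    except ImportError:
        return "cpu"
    # torch.backends.mps only exists in PyTorch 1.12 and later.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = bool(mps_backend is not None and mps_backend.is_available())
    return pick_device(torch.cuda.is_available(), mps_available, allow_mps)
```

With the default `allow_mps=False`, `detect_device()` never returns `"mps"`, so a custom training loop could pass its result to `model.to(...)` without touching the MPS backend.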

zhaochenyang20 commented 1 year ago

I can't find an easy, clear way to disable MPS usage. 🤔

The following code failed:

    from transformers import Seq2SeqTrainingArguments

    # User overrides for the defaults below; empty here, so every default applies.
    hyperparameter_choices = {}

    training_args = Seq2SeqTrainingArguments(
        output_dir=hyperparameter_choices.get("output_dir", "./result"),
        logging_steps=hyperparameter_choices.get("logging_steps", 1),
        save_strategy=hyperparameter_choices.get("save_strategy", "no"),
        num_train_epochs=hyperparameter_choices.get("num_train_epochs", 10),
        per_device_train_batch_size=hyperparameter_choices.get(
            "per_device_train_batch_size", 100
        ),
        warmup_steps=hyperparameter_choices.get("warmup_steps", 0),
        weight_decay=hyperparameter_choices.get("weight_decay", 0.01),
        logging_dir=hyperparameter_choices.get("logging_dir", "./result"),
        learning_rate=hyperparameter_choices.get("learning_rate", 1e-4),
        predict_with_generate=True,
        # Intended to keep training off the MPS device.
        use_mps_device=False,
    )

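Recent transformers releases document `use_mps_device` as deprecated and select MPS automatically when it is available, so `use_mps_device=False` alone may not keep training off MPS there. The following is a minimal sketch, assuming an installed release whose `TrainingArguments` accept `use_cpu=True`. Check this against your version. `DEFAULTS` and `build_training_kwargs` are hypothetical names that only collect the keyword arguments from the snippet above and swap in the CPU switch.

```python
# Defaults copied from the Seq2SeqTrainingArguments call in the issue.
DEFAULTS = {
    "output_dir": "./result",
    "logging_steps": 1,
    "save_strategy": "no",
    "num_train_epochs": 10,
    "per_device_train_batch_size": 100,
    "warmup_steps": 0,
    "weight_decay": 0.01,
    "logging_dir": "./result",
    "learning_rate": 1e-4,
}


def build_training_kwargs(hyperparameter_choices: dict, force_cpu: bool = False) -> dict:
    """Hypothetical helper: merge user choices over DEFAULTS into a kwargs dict."""
    kwargs = {key: hyperparameter_choices.get(key, default) for key, default in DEFAULTS.items()}
    kwargs["predict_with_generate"] = True
    if force_cpu:
        # Assumption: the installed transformers accepts use_cpu=True, which
        # skips both CUDA and MPS. Older releases used no_cuda instead.
        kwargs["use_cpu"] = True
    return kwargs
```

The arguments would then be built with `Seq2SeqTrainingArguments(**build_training_kwargs(hyperparameter_choices, force_cpu=True))`. Keeping the kwargs in a plain dict makes it easy to inspect or adjust the device flags before transformers sees them.
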
neubig commented 1 year ago

To clarify, what is the error you get when you run that code?

zhaochenyang20 commented 1 year ago

Emmm. I'll post the error later.