mingkaid / rl-prompt

Accompanying repo for the RLPrompt paper
MIT License

Classification with GPT #26

MatthewCYM closed this 1 year ago

MatthewCYM commented 1 year ago

Hi,

May I ask whether the current code supports using GPT models on classification tasks?

Thanks.

MM-IR commented 1 year ago

Hi,

Our implementation already supports GPT-2 models.

Thanks. Since this is a clarification question, I am closing it now.

MatthewCYM commented 1 year ago

Thanks for the quick reply. I ran the code with:

python run_fsc.py \
    dataset=agnews \
    dataset_seed=0 \
    prompt_length=5 \
    task_lm=roberta-large \
    random_seed=42 \
    report_to_wandb=false

Training takes around one day to complete on a single RTX 3090, much longer than the roughly 4 hours reported in the paper. Is this expected?
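To check wall-clock time against the paper's figure, one option is to launch the same command from a small Python timer. This is a minimal sketch, assuming it is run from the directory containing run_fsc.py; the helpers build_fsc_command and timed_run are hypothetical and only reproduce the overrides shown in the command above.

```python
import subprocess
import sys
import time


def build_fsc_command(task_lm: str, dataset: str = "agnews", dataset_seed: int = 0,
                      prompt_length: int = 5, random_seed: int = 42) -> list[str]:
    """Build the run_fsc.py invocation with the same overrides as above (hypothetical helper)."""
    return [
        sys.executable, "run_fsc.py",
        f"dataset={dataset}",
        f"dataset_seed={dataset_seed}",
        f"prompt_length={prompt_length}",
        f"task_lm={task_lm}",
        f"random_seed={random_seed}",
        "report_to_wandb=false",
    ]


def timed_run(cmd: list[str]) -> float:
    """Run the command, raise CalledProcessError on failure, and return elapsed wall-clock hours."""
    start = time.perf_counter()
    subprocess.run(cmd, check=True)
    return (time.perf_counter() - start) / 3600.0
```

For example, `timed_run(build_fsc_command("roberta-large"))` launches the first command and returns the elapsed hours; passing "gpt2-xl" instead times the second one.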

I also tried running the code with a GPT-2 backbone:

python run_fsc.py \
    dataset=agnews \
    dataset_seed=0 \
    prompt_length=5 \
    task_lm=gpt2-xl \
    random_seed=42 \
    report_to_wandb=false

The evaluation accuracy is only 62.5. Have you experimented with GPT-2 on classification tasks?
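For context on how a left-to-right model such as GPT-2 can be used as a classifier, a common approach is to read the next-token logits after the prompt and input, keep only one "verbalizer" token per class, and take the argmax. The sketch below illustrates that generic readout on toy numbers; it is not RLPrompt's implementation, and the function names and token ids are hypothetical.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    """Numerically stable softmax over a short list of logits."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def classify_next_token(next_token_logits: list[float],
                        verbalizer_ids: list[int]) -> tuple[int, list[float]]:
    """Predict a class from next-token logits restricted to verbalizer tokens.

    next_token_logits: logits over the whole vocabulary at the last prompt position.
    verbalizer_ids: one (hypothetical) token id per class, in class order.
    Returns the predicted class index and the probabilities over the verbalizers only.
    """
    class_logits = [next_token_logits[i] for i in verbalizer_ids]
    probs = softmax(class_logits)
    pred = max(range(len(probs)), key=probs.__getitem__)
    return pred, probs


def accuracy(preds: list[int], labels: list[int]) -> float:
    """Accuracy as a percentage."""
    return 100.0 * sum(p == y for p, y in zip(preds, labels)) / len(labels)
```

Tokens outside the verbalizer set are ignored even if they have the highest logit, so the choice of label words directly affects accuracy.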