mingkaid / rl-prompt

Accompanying repo for the RLPrompt paper
MIT License

assert len(prompt_strs) == len(source_strs) fails during inference #15

Closed mahdiabdollahpour closed 1 year ago

mahdiabdollahpour commented 1 year ago

Hi,

Thanks for sharing the code.

The assertion at https://github.com/mingkaid/rl-prompt/blob/main/examples/text-style-transfer/tst_reward.py#L65 fails during inference, although it passes during training. During inference the two lengths are 16 and 4, whereas during training they are 8 and 8. Do you know whether the code here needs an update?

    def forward(
        self,
        source_texts: List[str],
        target_labels: List[str],
        output_tokens: List[List[str]],
        to_tensor: bool,
        mode: str
    ) -> Tuple[Union[List[float], torch.Tensor], Dict[str, Any]]:
        if mode == 'train':
            self._counter += 1
            source_strs = self._repeat_texts(source_texts)
            target_labels = self._repeat_texts(target_labels)
        elif mode == "infer":
            source_strs = source_texts
        else:
            raise ValueError

        prompt_tokens = output_tokens
        prompt_strs = self._convert_tokens_to_string(prompt_tokens)
        assert len(prompt_strs) == len(source_strs)
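
For anyone debugging the same failure, here is a minimal sketch of a check that reports both lengths instead of raising a bare AssertionError. `check_batch_alignment` is a hypothetical helper, not part of rl-prompt, and it is not the fix that was later merged. Mapping 16 to the prompts and 4 to the source texts is an assumption based on the order of the operands in the assert.

```python
from typing import List


def check_batch_alignment(
    prompt_strs: List[str], source_strs: List[str], mode: str
) -> None:
    """Raise a descriptive error when prompts and sources are not paired one-to-one."""
    if len(prompt_strs) != len(source_strs):
        raise ValueError(
            f"mode={mode!r}: got {len(prompt_strs)} prompts "
            f"but {len(source_strs)} source texts"
        )


# Lengths reported above: 8 and 8 in training, 16 and 4 in inference.
# Which length belongs to which list is an assumption.
check_batch_alignment(["p"] * 8, ["s"] * 8, mode="train")  # passes silently

try:
    check_batch_alignment(["p"] * 16, ["s"] * 4, mode="infer")
except ValueError as err:
    message = str(err)
    print(message)
```

Swapping a check like this in for the assert makes it clear from the error message alone which mode and which batch sizes triggered the mismatch.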

I ran the code with this command:

python run_tst.py \
    dataset=shakespeare \
    dataset_seed=1 \
    direction=0_to_1 \
    prompt_length=5 \
    task_lm=distilgpt2 \
    lower_outputs=false \
    random_seed=100
mingkaid commented 1 year ago

Hi! Thank you for your interest and sorry for the delay. Could you provide the commands to reproduce the error, starting from a fresh environment?

mahdiabdollahpour commented 1 year ago

These are the exact commands:

virtualenv --python=python3 rl_env2
source rl_env2/bin/activate

git clone https://github.com/mingkaid/rl-prompt

cd rl-prompt/
pip install -e .

cd examples/text-style-transfer/
pip install -r requirements.txt

python scripts/download_tst_classifiers.py --model_name yelp-train
python scripts/download_tst_classifiers.py --model_name shakespeare-train-100-0
python scripts/download_tst_classifiers.py --model_name shakespeare-train-100-1
python scripts/download_tst_classifiers.py --model_name shakespeare-train-100-2

Then:

python run_tst.py \
    dataset=shakespeare \
    dataset_seed=1 \
    direction=0_to_1 \
    prompt_length=5 \
    task_lm=distilgpt2 \
    lower_outputs=false \
    random_seed=100

(I chose option 3 at the wandb prompt.)

It runs for a while and then fails with this error:

error executing job with overrides: ['dataset=shakespeare', 'dataset_seed=1', 'direction=0_to_1', 'prompt_length=5', 'task_lm=distilgpt2', 'lower_outputs=false', 'random_seed=100']
Traceback (most recent call last):
  File "run_tst.py", line 47, in main
    trainer.train(config=config)
  File "/media/Storage/CTG/rl_prompt2/rl-prompt/rlprompt/trainers/trainer.py", line 167, in train
    eval_log = self.evaluate(output_save_path=output_save_path)
  File "/media/Storage/CTG/rl_prompt2/rl-prompt/rlprompt/trainers/trainer.py", line 216, in evaluate
    output_tokens=infer_outputs['sample_tokens'])
  File "/media/Storage/CTG/rl_prompt2/rl-prompt/rlprompt/modules/sql_module.py", line 165, in compute_rewards
    mode=mode)
  File "/media/Storage/CTG/rl_prompt2/rl-prompt/rlprompt/rewards/base_reward.py", line 3, in __call__
    return self.forward(*args, **kwargs)
  File "/media/Storage/CTG/rl_prompt2/rl-prompt/examples/text-style-transfer/tst_reward.py", line 84, in forward
    assert len(prompt_strs) == len(source_strs)
AssertionError
mahdiabdollahpour commented 1 year ago

And here is the output of pip freeze:


antlr4-python3-runtime==4.9.3
appdirs==1.4.4
bert-score==0.3.12
certifi==2022.12.7
charset-normalizer==3.0.1
click==8.1.3
colorama==0.4.6
cycler==0.11.0
docker-pycreds==0.4.0
filelock==3.9.0
fonttools==4.38.0
gitdb==4.0.10
GitPython==3.1.30
huggingface-hub==0.12.0
hydra-core==1.2.0
idna==3.4
importlib-metadata==6.0.0
importlib-resources==5.10.2
kiwisolver==1.4.4
lxml==4.9.2
matplotlib==3.5.3
numpy==1.21.6
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
omegaconf==2.3.0
packaging==23.0
pandas==1.3.5
pathtools==0.1.2
Pillow==9.4.0
pkg_resources==0.0.0
portalocker==2.7.0
protobuf==4.21.12
psutil==5.9.4
pyparsing==3.0.9
python-dateutil==2.8.2
pytz==2022.7.1
PyYAML==6.0
regex==2022.10.31
requests==2.28.2
-e git+https://github.com/mingkaid/rl-prompt@1802e394c57df752597a44cc7f17657e03b52f04#egg=rl_prompt
sacrebleu==2.3.1
sentry-sdk==1.14.0
setproctitle==1.3.2
six==1.16.0
smmap==5.0.0
tabulate==0.9.0
tokenizers==0.13.2
torch==1.13.1
tqdm==4.64.1
transformers==4.26.0
typing==3.7.4.3
typing_extensions==4.4.0
urllib3==1.26.14
wandb==0.13.9
zipp==3.12.0
mingkaid commented 1 year ago

Thank you for sharing the commands, and again for your feedback! After investigating, we found that it was indeed a bug in our code. We have fixed it in #16. Feel free to pull the updates and try again.

I'll close the issue now because the problem has been addressed. Let me know if you have any questions.