facebookresearch / advprompter

Official implementation of AdvPrompter: https://arxiv.org/abs/2404.16873

ValueError: Trying to repeat_interleave sequence with bs=1, this might break broadcasting. #10

ish3lan closed this issue 1 month ago

ish3lan commented 3 months ago

The run crashes after completing 47 of 48 steps (98%) of the first training epoch:

```
Initializing Wandb...
wandb: Currently logged in as: XXXXX(XXXXX). Use wandb login --relogin to force relogin
wandb: wandb version 0.17.6 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
wandb: Tracking run with wandb version 0.16.6
wandb: Run data is saved locally in /home/XXXXX/certification/advprompter/advprompter/wandb/run-20240810_152449-o3oqevnp
wandb: Run wandb offline to turn off syncing.
wandb: Syncing run upbeat-pond-17
wandb: ⭐️ View project at https://wandb.ai/XXXXX/advprompter
wandb: 🚀 View run at https://wandb.ai/XXXXX/advprompter/runs/o3oqevnp
Initializing Prompter...
 Loading model: tiny_llama from TinyLlama/TinyLlama-1.1B-step-50K-105b...
/home/XXXXX/.conda/envs/advprompter/lib/python3.11/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: resume_download is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use force_download=True.
  warnings.warn(
 Mem usage model: 4.42 GB | Total Mem usage: 4.42 GB
 Transforming to LoRA model...
 trainable params: 1398784 || all params: 1101447168 || trainable%: 0.13
Initializing TargetLLM...
 Loading model: vicuna-7b-v1.5 from lmsys/vicuna-7b-v1.5...
Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
/home/XXXXX/.conda/envs/advprompter/lib/python3.11/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Loading checkpoint shards: 100%|██████████| 2/2 [00:35<00:00, 17.55s/it]
/home/XXXXX/.conda/envs/advprompter/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: do_sample is set to False. However, temperature is set to 0.9 -- this flag is only used in sample-based generation modes. You should set do_sample=True or unset temperature. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
  warnings.warn(
/home/XXXXX/.conda/envs/advprompter/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: do_sample is set to False. However, top_p is set to 0.6 -- this flag is only used in sample-based generation modes. You should set do_sample=True or unset top_p. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.
  warnings.warn(
/home/XXXXX/.conda/envs/advprompter/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:492: UserWarning: do_sample is set to False. However, temperature is set to 0.9 -- this flag is only used in sample-based generation modes. You should set do_sample=True or unset temperature.
  warnings.warn(
/home/XXXXX/.conda/envs/advprompter/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:497: UserWarning: do_sample is set to False. However, top_p is set to 0.6 -- this flag is only used in sample-based generation modes. You should set do_sample=True or unset top_p.
  warnings.warn(
 Mem usage model: 13.54 GB | Total Mem usage: 17.97 GB
 Freezing model...
 trainable params: 0 || all params: 6738415616 || trainable%: 0.00
Starting training...
Training (epochs):   0%|          | 0/10 [00:00<?, ?it/s]
| 2/48 [05:15<2:03:20, 160.89s/it]
/home/XXXXX/.conda/envs/advprompter/lib/python3.11/site-packages/torchrl/data/replay_buffers/samplers.py:318: RuntimeWarning: overflow encountered in power
  priority = np.power(priority + self._eps, self._alpha)
/home/XXXXX/.conda/envs/advprompter/lib/python3.11/site-packages/torchrl/data/replay_buffers/samplers.py:229: RuntimeWarning: overflow encountered in scalar power
  return (self._max_priority + self._eps) ** self._alpha
Training epoch 0:  98%|█████████▊| 47/48 [2:04:50<02:39, 159.37s/it]
Training (epochs):   0%|          | 0/10 [2:04:50<?, ?it/s]
Error executing job with overrides: []
Traceback (most recent call last):
  File "/home/XXXXX/certification/advprompter/advprompter/main.py", line 661, in main
    workspace.train()
  File "/home/XXXXX/certification/advprompter/advprompter/main.py", line 167, in train
    self.train_epoch()
  File "/home/XXXXX/certification/advprompter/advprompter/main.py", line 237, in train_epoch
    suffix = advPrompterOpt(
             ^^^^^^^^^^^^^^^
  File "/home/XXXXX/.conda/envs/advprompter/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXXXX/certification/advprompter/advprompter/advprompteropt.py", line 46, in advPrompterOpt
    instruct_rep = instruct.repeat_interleave(num_beams_in, dim=0)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/XXXXX/certification/advprompter/advprompter/sequence.py", line 654, in repeat_interleave
    raise ValueError(
ValueError: Trying to repeat_interleave sequence with bs=1, this might break broadcasting.

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
```
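The traceback shows the crash inside `instruct.repeat_interleave(num_beams_in, dim=0)` (advprompteropt.py, line 46), which raises from the `repeat_interleave` method in sequence.py. The pure-Python sketch below is not the repository's code; it is a hypothetical stand-in (`repeat_interleave_rows` is an illustrative name) that mimics repeating each row of a batch along dimension 0 and adds a guard with the same message. Our reading of the message, which the thread does not confirm, is that a batch of one row is refused because a leading dimension of 1 can broadcast silently against any batch size and hide shape mistakes.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def repeat_interleave_rows(rows: Sequence[T], repeats: int) -> List[T]:
    """Repeat each row `repeats` times consecutively, like repeat_interleave(dim=0).

    Hypothetical stand-in for the repeat_interleave method seen in the traceback.
    """
    if len(rows) == 1:
        # Hypothetical guard mirroring the logged error: a leading dimension of 1
        # could broadcast against any batch size, so refuse it explicitly.
        raise ValueError(
            "Trying to repeat_interleave sequence with bs=1, "
            "this might break broadcasting."
        )
    # Row i ends up at positions i*repeats .. i*repeats + repeats - 1.
    return [row for row in rows for _ in range(repeats)]


print(repeat_interleave_rows(["a", "b"], 3))  # ['a', 'a', 'a', 'b', 'b', 'b']

try:
    repeat_interleave_rows(["only one"], 3)
except ValueError as err:
    print(err)
```

With two rows, each row is repeated in place; with a single row, the guard raises the same `ValueError` message that appears in the log.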
arman-z commented 2 months ago

Are you using a custom dataset? @a-paulus would know more about this issue. An easy workaround could be to duplicate some rows in your dataset so that bs is no longer equal to 1...
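To make the link between the dataset size and `bs` concrete: if the training loader splits N rows into batches of size B and keeps the final partial batch (an assumption; the thread does not show the loader settings), the last batch holds N mod B rows, or B when N is an exact multiple of B. When N mod B equals 1, the last step of each epoch sees bs=1, which would be consistent with the crash on step 48 of 48 in the log above. Duplicating one row then makes that final batch two rows whenever B ≥ 2. The numbers in the sketch are illustrative and are not taken from the issue.

```python
def last_batch_size(n_rows: int, batch_size: int) -> int:
    """Size of the final batch when the last partial batch is kept (no drop_last)."""
    if n_rows <= 0 or batch_size <= 0:
        raise ValueError("n_rows and batch_size must be positive")
    remainder = n_rows % batch_size
    return remainder if remainder else batch_size


def rows_to_duplicate(n_rows: int, batch_size: int) -> int:
    """Extra rows needed so the final batch is not a single row."""
    if batch_size < 2:
        # With batch_size == 1 every batch has one row; duplicating cannot help.
        raise ValueError("batch_size must be at least 2")
    return 1 if last_batch_size(n_rows, batch_size) == 1 else 0


# Illustrative numbers only: 377 rows with B = 8 gives 48 steps, the last with 1 row.
n, b = 47 * 8 + 1, 8
print(n, last_batch_size(n, b))   # 377 1
print(last_batch_size(n + 1, b))  # 2
print(rows_to_duplicate(n, b))    # 1
```

If the loader happens to be a `torch.utils.data.DataLoader`, passing `drop_last=True` is another way to avoid the partial batch, at the cost of skipping those rows each epoch; whether AdvPrompter exposes that option is not stated in the thread.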

ish3lan commented 2 months ago

Yes, I am using a custom dataset. Sorry, I didn't understand your suggestion about cloning a row in my dataset. Do you mean duplicating any row in my dataset? Could you please explain how that relates to bs? Thank you.

arman-z commented 2 months ago

The last line in your log states:

ValueError: Trying to repeat_interleave sequence with bs=1, this might break broadcasting.

So I'd guess this is the main cause. Duplicating any row of your dataset may work around the problem, although it is a bit clumsy.
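A minimal sketch of that workaround, assuming the custom dataset is a CSV file with a header row. The thread does not confirm the dataset format or the batch size, and `pad_csv_for_batches`, the file names, and the column names are hypothetical placeholders. The function copies the file and appends a duplicate of the last data row only when the row count would leave a final batch of one.

```python
import csv
import tempfile
from pathlib import Path


def pad_csv_for_batches(src: Path, dst: Path, batch_size: int) -> int:
    """Copy src to dst, duplicating the last data row if len(rows) % batch_size == 1.

    Returns the number of data rows written. Assumes a header row and batch_size >= 2.
    """
    if batch_size < 2:
        raise ValueError("batch_size must be at least 2 for this workaround")
    with src.open(newline="") as f:
        reader = csv.reader(f)
        header = next(reader)
        rows = list(reader)
    if rows and len(rows) % batch_size == 1:
        # One extra copy turns a one-row final batch into a two-row batch.
        rows.append(list(rows[-1]))
    with dst.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)
    return len(rows)


# Demo on a throwaway file: 9 data rows with batch_size=8 would end in a 1-row batch.
with tempfile.TemporaryDirectory() as tmp:
    src, dst = Path(tmp) / "train.csv", Path(tmp) / "train_padded.csv"  # hypothetical names
    with src.open("w", newline="") as f:
        csv.writer(f).writerows([["col_a", "col_b"]] + [[f"q{i}", f"t{i}"] for i in range(9)])
    print(pad_csv_for_batches(src, dst, batch_size=8))  # 10
```

Duplicating the last row keeps the change to the training distribution as small as possible; any row would do, as suggested above.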