Muennighoff / sgpt

SGPT: GPT Sentence Embeddings for Semantic Search
https://arxiv.org/abs/2202.08904
MIT License

When I train a bi-encoder using bloom-3b, I get this error. What is the cause of this problem, please? #27

Closed ScottishFold007 closed 1 year ago

ScottishFold007 commented 1 year ago

Here is the command line I used:

### https://huggingface.co/docs/accelerate/basic_tutorials/launch

!accelerate launch --multi_gpu --mixed_precision bf16 --num_processes 7 train_bi-encoder_mnrl.py \
--train_batch_size 8  \
--eval_batch_size 8 \
--lr 2e-5  \
--epochs 5 \
--asym \
--pooling weightedmean \
--max_seq_length 512 \
--wandbwatchlog gradients \
--specb  \
--freezenonbias  \
--gradcache \
--chunksize 4

Then the following error was reported:

NotImplementedError: Model input split not implemented for type <class 'dict'>
Iteration:   0%|                                       | 0/2743 [00:01<?, ?it/s]
Epoch:   0%|                                              | 0/5 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/root/data_process/train_bi-encoder_mnrl.py", line 375, in <module>
    model.fit(train_objectives=[(train_dataloader, train_loss)],
  File "/root/data_process/sentence_transformers/SentenceTransformer.py", line 801, in fit
    loss_value = loss_model(features, labels)
  File "/root/data_process/sentence_transformers/losses/MultipleNegativesRankingLoss.py", line 153, in __call__
    return super().__call__(*sentence_features, no_sync_except_last=no_sync_except_last)
  File "/data/anaconda3/lib/python3.10/site-packages/grad_cache/grad_cache.py", line 70, in __call__
    return self.cache_step(*args, **kwargs)
  File "/data/anaconda3/lib/python3.10/site-packages/grad_cache/grad_cache.py", line 266, in cache_step
    model_inputs = [self.split_inputs(x, chunk_size) for x, chunk_size in zip(model_inputs, self.chunk_sizes)]
  File "/data/anaconda3/lib/python3.10/site-packages/grad_cache/grad_cache.py", line 266, in <listcomp>
    model_inputs = [self.split_inputs(x, chunk_size) for x, chunk_size in zip(model_inputs, self.chunk_sizes)]
  File "/data/anaconda3/lib/python3.10/site-packages/grad_cache/grad_cache.py", line 102, in split_inputs
    raise NotImplementedError(f'Model input split not implemented for type {type(model_input)}')
NotImplementedError: Model input split not implemented for type <class 'dict'>

But when I drop --asym, everything works perfectly, so I don't know what the problem is. Can you help me out? Thank you!

Muennighoff commented 1 year ago

I think you don't want to use the --asym argument: it creates a separate model for queries and documents, which performs much worse than the default of tying the weights of the query and document encoders. I haven't tested the --asym kwarg with GradCache; I think the problem is that with --asym the model is not a single torch model but two models (i.e. a dict), one for queries and one for documents. In any case, since --asym performs badly, I just wouldn't use it.
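
If you do want to make it compatible, the missing piece is roughly a dict case in grad_cache's split_inputs. A minimal, untested sketch of what that could look like (an assumption about the fix, not code from the released GradCache):

import torch

# Untested sketch: teach split_inputs to chunk plain dicts (as produced by the
# --asym model) by splitting each value and regrouping the pieces into one
# dict per chunk. Written against the grad_cache.py shown in the traceback.
def split_inputs(model_input, chunk_size):
    if isinstance(model_input, torch.Tensor):
        return list(model_input.split(chunk_size, dim=0))
    if isinstance(model_input, dict):
        keys = list(model_input.keys())
        # split each value recursively, then re-zip chunk-wise
        chunked_values = [split_inputs(model_input[k], chunk_size) for k in keys]
        return [dict(zip(keys, chunk)) for chunk in zip(*chunked_values)]
    raise NotImplementedError(f'Model input split not implemented for type {type(model_input)}')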

ScottishFold007 commented 1 year ago

But I'm wondering: if the query is short but the passage is very long, can good retrieval quality still be guaranteed without the asym setup? Will there be inaccuracies? Also, could the worse performance be because the passages in the training set are generally not very long?

Muennighoff commented 1 year ago

But I'm wondering: if the query is short but the passage is very long, can good retrieval quality still be guaranteed without the asym setup?

The way SGPT distinguishes them is via --specb, which adds different brackets depending on whether the input is a query or a document. I had the same intuition as you, but it seems that --asym just does not work well. You can still try it with the code as is, but you will need to remove --gradcache --chunksize 4 (which will need more memory) or make it compatible.
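
For illustration, a minimal sketch of the bracket idea, assuming the scheme from the SGPT paper (queries wrapped in square brackets, documents in curly braces); the actual implementation in the repo operates on token ids rather than raw strings:

# Sketch of --specb (special brackets): queries and documents get different
# surrounding brackets so the shared encoder can tell the two roles apart.
def add_specb(text: str, is_query: bool) -> str:
    return "[" + text + "]" if is_query else "{" + text + "}"

print(add_specb("how do transformers work?", is_query=True))   # [how do transformers work?]
print(add_specb("Transformers are ...", is_query=False))       # {Transformers are ...}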

ScottishFold007 commented 1 year ago

The way SGPT distinguishes them is via --specb, which adds different brackets depending on whether the input is a query or a document. I had the same intuition as you, but it seems that --asym just does not work well. You can still try it with the code as is, but you will need to remove --gradcache --chunksize 4 (which will need more memory) or make it compatible.

Well, one more question. The training script chooses the accelerator like this:

if "gpt" in model_name:
    accelerator = Accelerator()
else:
    # Needed to run e.g. bert-large-uncased (also works for GPT, but uses unnecessary memory)
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

For a model with an architecture like bloom, is it also necessary to use the first branch (accelerator = Accelerator()) instead of the else branch?

Muennighoff commented 1 year ago

If it doesn't error out, then I think it's fine. You can probably use less memory if you change the condition to if ("gpt" in model_name) or ("bloom" in model_name):
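
i.e. something like this in the training script (model_name here is the script's model argument):

from accelerate import Accelerator, DistributedDataParallelKwargs

if ("gpt" in model_name) or ("bloom" in model_name):
    # per the discussion above, GPT/BLOOM-style models don't need
    # find_unused_parameters, which costs extra memory
    accelerator = Accelerator()
else:
    # Needed to run e.g. bert-large-uncased (also works for GPT, but uses unnecessary memory)
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])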

ScottishFold007 commented 1 year ago

If it doesn't error out, then I think it's fine. You can probably use less memory if you change the condition to if ("gpt" in model_name) or ("bloom" in model_name):

Okay, I understand. Thank you very much for your thorough answer! I wish you a happy life! All the best!

Muennighoff commented 1 year ago

If it doesn't error out, then I think it's fine. You can probably use less memory if you change the condition to if ("gpt" in model_name) or ("bloom" in model_name):

Okay, I understand. Thank you very much for your thorough answer! I wish you a happy life! All the best!

Happy to be of help! 👍

ScottishFold007 commented 1 year ago

But I'm wondering: if the query is short but the passage is very long, can good retrieval quality still be guaranteed without the asym setup?

The way SGPT distinguishes them is via --specb, which adds different brackets depending on whether the input is a query or a document. I had the same intuition as you, but it seems that --asym just does not work well. You can still try it with the code as is, but you will need to remove --gradcache --chunksize 4 (which will need more memory) or make it compatible.

When I removed --gradcache --chunksize 4, the original code ran out of memory (OOM), but with GradCache enabled it worked even with chunksize set to 8.

Muennighoff commented 1 year ago

Yeah, you need GradCache to train with large batch sizes.
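
For context: GradCache (Gao et al., "Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup") makes this possible with a two-pass trick: embed the batch chunk by chunk without building a graph, compute the contrastive loss over the full batch of cached embeddings, then re-encode each chunk with gradients enabled and backpropagate the cached embedding gradients. A minimal sketch of the idea, assuming a hypothetical encoder and loss_fn and a single-tensor batch; the real library also handles dict inputs, mixed precision, and distributed training:

import torch

def gradcache_step(encoder, batch, loss_fn, chunk_size, optimizer):
    # Pass 1: embed all chunks without autograd, so activation memory is
    # bounded by chunk_size instead of the full batch size.
    chunks = batch.split(chunk_size, dim=0)
    with torch.no_grad():
        reps = torch.cat([encoder(c) for c in chunks], dim=0)
    reps = reps.detach().requires_grad_()

    # Contrastive loss over the FULL batch of cached embeddings.
    loss = loss_fn(reps)
    loss.backward()  # yields gradients w.r.t. the cached embeddings
    rep_grads = reps.grad.split(chunk_size, dim=0)

    # Pass 2: re-encode each chunk with a graph and inject the cached grads,
    # accumulating parameter gradients chunk by chunk.
    for c, g in zip(chunks, rep_grads):
        encoder(c).backward(gradient=g)

    optimizer.step()
    optimizer.zero_grad()
    return loss.item()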