AnswerDotAI / RAGatouille

Easily use and train state of the art late-interaction retrieval methods (ColBERT) in any RAG pipeline. Designed for modularity and ease-of-use, backed by research.
Apache License 2.0

answerai-colbert-small-v1 issue #239

Closed · MohammedAlhajji closed this issue 3 months ago

MohammedAlhajji commented 3 months ago

I am not sure what the issue is here, so it may be me doing something stupid. I have the latest version of RAGatouille.

Running this code:

from ragatouille import RAGPretrainedModel
from ragatouille.utils import get_wikipedia_page

RAG = RAGPretrainedModel.from_pretrained("answerdotai/answerai-colbert-small-v1")
my_documents = [
    get_wikipedia_page("Hayao_Miyazaki"),
    get_wikipedia_page("Studio_Ghibli"),
]
index_path = RAG.index(index_name="my_index", collection=my_documents)

yields the following error:

{
    "name": "AssertionError",
    "message": "(96, 8)",
    "stack": "---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[1], line 9
      4 RAG = RAGPretrainedModel.from_pretrained(\"answerdotai/answerai-colbert-small-v1\")
      5 my_documents = [
      6     get_wikipedia_page(\"Hayao_Miyazaki\"),
      7     get_wikipedia_page(\"Studio_Ghibli\"),
      8 ]
----> 9 index_path = RAG.index(index_name=\"my_index\", collection=my_documents)

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/ragatouille/RAGPretrainedModel.py:211, in RAGPretrainedModel.index(self, collection, document_ids, document_metadatas, index_name, overwrite_index, max_document_length, split_documents, document_splitter_fn, preprocessing_fn, bsize, use_faiss)
    202     document_splitter_fn = None
    203 collection, pid_docid_map, docid_metadata_map = self._process_corpus(
    204     collection,
    205     document_ids,
   (...)
    209     max_document_length,
    210 )
--> 211 return self.model.index(
    212     collection,
    213     pid_docid_map=pid_docid_map,
    214     docid_metadata_map=docid_metadata_map,
    215     index_name=index_name,
    216     max_document_length=max_document_length,
    217     overwrite=overwrite_index,
    218     bsize=bsize,
    219     use_faiss=use_faiss,
    220 )

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/ragatouille/models/colbert.py:338, in ColBERT.index(self, collection, pid_docid_map, docid_metadata_map, index_name, max_document_length, overwrite, bsize, use_faiss)
    334     self.docid_pid_map[docid].append(pid)
    336 self.docid_metadata_map = docid_metadata_map
--> 338 self.model_index = ModelIndexFactory.construct(
    339     \"PLAID\",
    340     self.config,
    341     self.checkpoint,
    342     self.collection,
    343     self.index_name,
    344     overwrite,
    345     verbose=self.verbose != 0,
    346     bsize=bsize,
    347     use_faiss=use_faiss,
    348 )
    349 self.config = self.model_index.config
    350 self._save_index_metadata()

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/ragatouille/models/index.py:485, in ModelIndexFactory.construct(index_type, config, checkpoint, collection, index_name, overwrite, verbose, **kwargs)
    482 if index_type == \"auto\":
    483     # NOTE: For now only PLAID indexes are supported.
    484     index_type = \"PLAID\"
--> 485 return ModelIndexFactory._MODEL_INDEX_BY_NAME[
    486     ModelIndexFactory._raise_if_invalid_index_type(index_type)
    487 ].construct(
    488     config, checkpoint, collection, index_name, overwrite, verbose, **kwargs
    489 )

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/ragatouille/models/index.py:150, in PLAIDModelIndex.construct(config, checkpoint, collection, index_name, overwrite, verbose, **kwargs)
    140 @staticmethod
    141 def construct(
    142     config: ColBERTConfig,
   (...)
    148     **kwargs,
    149 ) -> \"PLAIDModelIndex\":
--> 150     return PLAIDModelIndex(config).build(
    151         checkpoint, collection, index_name, overwrite, verbose, **kwargs
    152     )

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/ragatouille/models/index.py:254, in PLAIDModelIndex.build(self, checkpoint, collection, index_name, overwrite, verbose, **kwargs)
    248     indexer = Indexer(
    249         checkpoint=checkpoint,
    250         config=self.config,
    251         verbose=verbose,
    252     )
    253     indexer.configure(avoid_fork_if_possible=True)
--> 254     indexer.index(name=index_name, collection=collection, overwrite=overwrite)
    256 return self

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/colbert/indexer.py:80, in Indexer.index(self, name, collection, overwrite)
     77     self.erase()
     79 if index_does_not_exist or overwrite != 'reuse':
---> 80     self.__launch(collection)
     82 return self.index_path

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/colbert/indexer.py:89, in Indexer.__launch(self, collection)
     87     shared_queues = []
     88     shared_lists = []
---> 89     launcher.launch_without_fork(self.config, collection, shared_lists, shared_queues, self.verbose)
     91     return
     93 manager = mp.Manager()

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/colbert/infra/launcher.py:93, in Launcher.launch_without_fork(self, custom_config, *args)
     90 assert (custom_config.avoid_fork_if_possible or self.run_config.avoid_fork_if_possible)
     92 new_config = type(custom_config).from_existing(custom_config, self.run_config, RunConfig(rank=0))
---> 93 return_val = run_process_without_mp(self.callee, new_config, *args)
     95 return return_val

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/colbert/infra/launcher.py:109, in run_process_without_mp(callee, config, *args)
    106 os.environ[\"CUDA_VISIBLE_DEVICES\"] = ','.join(map(str, config.gpus_[:config.nranks]))
    108 with Run().context(config, inherit_config=False):
--> 109     return_val = callee(config, *args)
    110     torch.cuda.empty_cache()
    111     return return_val

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/colbert/indexing/collection_indexer.py:33, in encode(config, collection, shared_lists, shared_queues, verbose)
     31 def encode(config, collection, shared_lists, shared_queues, verbose: int = 3):
     32     encoder = CollectionIndexer(config=config, collection=collection, verbose=verbose)
---> 33     encoder.run(shared_lists)

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/colbert/indexing/collection_indexer.py:72, in CollectionIndexer.run(self, shared_lists)
     69 distributed.barrier(self.rank)
     70 print_memory_stats(f'RANK:{self.rank}')
---> 72 self.index() # Encodes and saves all tokens into residuals
     73 distributed.barrier(self.rank)
     74 print_memory_stats(f'RANK:{self.rank}')

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/colbert/indexing/collection_indexer.py:375, in CollectionIndexer.index(self)
    371 if self.verbose > 1:
    372     Run().print_main(f\"#> Saving chunk {chunk_idx}: \\t {len(passages):,} passages \"
    373                     f\"and {embs.size(0):,} embeddings. From #{offset:,} onward.\")
--> 375 self.saver.save_chunk(chunk_idx, offset, embs, doclens) # offset = first passage index in chunk
    376 del embs, doclens

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/colbert/indexing/index_saver.py:71, in IndexSaver.save_chunk(self, chunk_idx, offset, embs, doclens)
     70 def save_chunk(self, chunk_idx, offset, embs, doclens):
---> 71     compressed_embs = self.codec.compress(embs)
     73     self.saver_queue.put((chunk_idx, offset, compressed_embs, doclens))

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/colbert/indexing/codecs/residual.py:179, in ResidualCodec.compress(self, embs)
    176     residuals_ = (batch - centroids_)
    178     codes.append(codes_.cpu())
--> 179     residuals.append(self.binarize(residuals_).cpu())
    181 codes = torch.cat(codes)
    182 residuals = torch.cat(residuals)

File /opt/homebrew/Caskroom/miniconda/base/envs/ragatouille/lib/python3.11/site-packages/colbert/indexing/codecs/residual.py:193, in ResidualCodec.binarize(self, residuals)
    190 residuals = residuals & 1  # apply mod 2 to binarize
    192 assert self.dim % 8 == 0
--> 193 assert self.dim % (self.nbits * 8) == 0, (self.dim, self.nbits)
    195 if self.use_gpu:
    196     residuals_packed = ResidualCodec.packbits(residuals.contiguous().flatten())

AssertionError: (96, 8)"
}

I don't get this error when running "colbert-ir/colbertv2.0" instead.

bclavie commented 3 months ago

Whoops, well spotted! This happens because indexing occasionally defaults to 8 bits for small collections, and this check asserts that dim % (nbits * 8) is 0, which isn't the case for answerai-colbert-small-v1 since it has a smaller embedding dimension (96 rather than 128). The check was superfluous anyway, as tests tend to show there isn't a sizeable gain from 8-bit compression, so I've removed it. Small collections (sub-10k documents) now compress to 4 bits by default, and anything above that to 2 bits. This should therefore be fixed in version 0.0.8post3.
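
For reference, a small sketch (not library code, just re-stating the assertion seen in the traceback) of why a 96-dim model trips the check at 8 bits but passes at 4 and 2 bits, while a 128-dim model passes at all three:

def passes_check(dim: int, nbits: int) -> bool:
    # Mirrors `assert self.dim % (self.nbits * 8) == 0` from ResidualCodec.binarize
    return dim % (nbits * 8) == 0

for dim in (96, 128):
    for nbits in (2, 4, 8):
        status = "ok" if passes_check(dim, nbits) else "AssertionError"
        print(f"dim={dim}, nbits={nbits}: {status}")

# dim=96  (answerai-colbert-small-v1): fails only at nbits=8, since 96 % 64 == 32
# dim=128 (colbert-ir/colbertv2.0):    passes at nbits=2, 4, and 8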

MitchProbst commented 2 months ago

Thank you for the quick fix here! We had hit the same issue and updating to 0.0.8post3 worked great for us!
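
A quick sketch (assuming the package was installed from PyPI and that `packaging` is available) to confirm the installed version includes the fix before re-running the indexing snippet above:

from importlib.metadata import version
from packaging.version import parse

installed = parse(version("ragatouille"))
print(installed)
# Per the thread, the fix landed in 0.0.8post3 (normalized as 0.0.8.post3)
assert installed >= parse("0.0.8post3"), "upgrade ragatouille to pick up the fix"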