StonyBrookNLP / ircot

Repository for "Interleaving Retrieval with Chain-of-Thought Reasoning for Knowledge-Intensive Multi-Step Questions" (ACL 2023)
https://arxiv.org/abs/2212.10509
Apache License 2.0

Where and How is the reason-step implemented? #12

Closed · kerkathy closed this issue 11 months ago

kerkathy commented 11 months ago

Hi,

I really appreciate your work and the carefully structured code!

In the paper, you mention that the reason-step generates the next CoT sentence based on:

  1. the question,
  2. the paragraphs retrieved so far, and
  3. the CoT sentences generated so far.

I wonder how the three components are combined. Did you simply concatenate them, i.e., something like concat(question, paragraph_1, paragraph_2, ..., CoT_sent_1, CoT_sent_2, ...)? Where is this part located in the code?

I tried to look it up, and it seems that you fetch the retrieved paragraphs in read_examples() in dataset_readers.py, where output_instance is returned as a list of dictionaries containing all the relevant information for each paragraph. And in inference_mode in configurable_inference.py the whole reasoning and answering somehow gets done. What happens there?

Also, I want to make sure that in this implementation, the unit of indexing / retrieval is a whole paragraph of a document, right? That means for each Wikipedia article, we only have one entry in the database, instead of splitting it into smaller chunks / passages.

Please feel free to correct me on any misunderstanding of mine. Thanks again for your effort šŸ˜Š

HarshTrivedi commented 11 months ago

Thank you for your interest!

I wonder how the three components are combined?

The components are combined in the prompt format given in the paper.

Wikipedia Title: <Page Title>
<Paragraph Text>

...

Wikipedia Title: <Page Title>
<Paragraph Text>

Q: <Question>
A: <CoT-Sent-1> <CoT-Sent-2> ... <CoT-Sent-n>
   ^1           ^2               ^n            # <- not part of prompt, only for explanation

For demonstrations, the full gold CoT is shown after "A:".

For the test instance, only the CoT sentences generated so far are shown after "A:", and the model generates the next one.
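As a rough illustration, here is a minimal sketch of how such a prompt could be assembled (a hypothetical helper, not the repo's actual code; all names are assumptions):

def build_ircot_prompt(paragraphs, question, cot_sentences):
    """Assemble the prompt format shown above (hypothetical sketch, not the
    repo's actual implementation).

    paragraphs: list of {"title": ..., "text": ...} dicts, in retrieval order
    cot_sentences: CoT sentences generated so far (full gold CoT for demos)
    """
    parts = [f"Wikipedia Title: {p['title']}\n{p['text']}\n" for p in paragraphs]
    parts.append(f"Q: {question}")
    # The model is asked to continue from here, producing the next CoT sentence.
    parts.append("A: " + " ".join(cot_sentences))
    return "\n".join(parts)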

Where is this part located in the code?

Most of the logic is implemented in commaqa/inference/ircot.py. But to navigate the code, you should start by looking at the relevant experiment config, say base_config/ircot_codex_hotpotqa.jsonnet, and at the models section in it. The control flow goes from one module (node) to the next, guided by the "next_model" key. Each module is a Python class defined in commaqa/inference/ircot.py. You can figure out which dict/node maps to which class by looking at its "name" field and finding the matching class in commaqa/inference/constants.py, e.g., step_by_step_cot_gen maps to StepByStepCOTGenParticipant. The specific code you are looking for is also in StepByStepCOTGenParticipant.
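As a rough mental model of that control flow (a hedged sketch, not the repo's actual code; everything other than the "next_model" and "name" keys is an assumption for illustration):

def run_pipeline(models_config, participant_classes, state, start_node):
    """Follow "next_model" pointers from node to node, dispatching each node
    to its participant class (hypothetical sketch of the described flow)."""
    current = start_node
    while current is not None:
        node = models_config[current]                      # one dict from the .jsonnet config
        participant = participant_classes[node["name"]]()  # e.g. StepByStepCOTGenParticipant
        state = participant.run(state)                     # this node's retrieval/generation step
        current = node.get("next_model")                   # follow the pointer; None ends the run
    return state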

The unit of indexing / retrieval is the whole paragraph ... That means for each wikipedia article, we only have one entry in the database

The unit of indexing is indeed the paragraph. But note that a Wikipedia article may contribute one or more paragraphs, depending on which dataset you're looking at. E.g., for HotpotQA, the official corpus only has the first paragraph of each page, i.e., one paragraph per page, so that's what we used. But for IIRC, all paragraphs of the page are part of the corpus, so that's what we use there. All the processing to create the individual dataset corpora is in processing_scripts/.
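For illustration, the difference in indexing granularity might look like this (field names are assumptions, not the repo's exact schema):

# HotpotQA: the official corpus has one entry (the first paragraph) per page.
hotpotqa_entries = [
    {"title": "Some Wikipedia Page", "paragraph_text": "First paragraph ..."},
]

# IIRC: every paragraph of the page is indexed as its own entry.
iirc_entries = [
    {"title": "Some Wikipedia Page", "paragraph_text": "Paragraph 1 ..."},
    {"title": "Some Wikipedia Page", "paragraph_text": "Paragraph 2 ..."},
]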

kerkathy commented 11 months ago

Hi, I appreciate your prompt and helpful answer! I followed your logic and got a much better understanding of the code. Just a few things to make sure of:

  1. By inspecting ircot_codex_hotpotqa.jsonnet as well as the RetrieveAndResetParagraphsParticipant class, it seems that for the HotpotQA case, we perform BM25 retrieval via Elasticsearch, then discard paragraphs that closely match already-retrieved ones, and finally keep at most the top 15 as the selected paragraphs.
  2. We evaluate using the fullwiki setting of HotpotQA (instead of the distractor setting).
  3. Contexts, questions, and CoT sentences are combined in the StepByStepCOTGenParticipant class in ircot.py. Here the contexts are the 15 paragraphs from retrieval.
  4. I see an add_and_reorder_if_pinned function that triggers reranking. I wonder if that comes from your special handling of the IIRC dataset, where "We always keep the main passage as part of the input to the model regardless of the retrieval strategy used."? If so, that means for the other datasets, paragraphs are not reranked before being joined into one context.

Am I understanding the flow correctly? Please correct me on any mistakes, thanks again!

HarshTrivedi commented 11 months ago
  1. After generating each CoT sentence, we retrieve K paragraphs using that sentence and append them to the context list. If a paragraph is already part of the list (or a very close one is), we do not add it. Once the list reaches the cap of 15 paragraphs, we do not add the rest. For the next CoT generation, everything accumulated in the context list up to that point is shown as context (in the order it was added). Before the first CoT generation, we retrieve K paragraphs based on the question. K here is a hyperparameter, and the retrieval happens via BM25. (See the sketch after this list.)

  2. Yes, correct.

  3. As mentioned above, 15 is the upper cap. In the initial steps, the context will be shorter.

  4. Correct. add_and_reorder_if_pinned is only relevant for IIRC. But it doesn't do any reranking. It just puts the "main_paragraph" (the pinned one) at the top or bottom, keeping everything else the same.
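Putting points 1-4 together, the accumulation loop could be sketched roughly as follows (a hedged sketch of the behavior described above; retrieve_bm25, is_near_duplicate, and generate_next_cot_sentence are hypothetical helpers, not the repo's actual functions):

MAX_PARAGRAPHS = 15  # the cap discussed above

def interleaved_retrieve_and_reason(question, k,
                                    retrieve_bm25, is_near_duplicate,
                                    generate_next_cot_sentence):
    """Sketch of the interleaved retrieve/reason loop (hypothetical helpers)."""
    context_list, cot_sentences = [], []
    query = question  # the first retrieval is based on the question
    while True:
        for para in retrieve_bm25(query, k):
            # Skip paragraphs already in (or very close to) the context list.
            if any(is_near_duplicate(para, seen) for seen in context_list):
                continue
            context_list.append(para)
        context_list = context_list[:MAX_PARAGRAPHS]  # overflow is dropped
        sentence = generate_next_cot_sentence(context_list, question, cot_sentences)
        if sentence is None:  # we stop only when the CoT finishes
            return context_list, cot_sentences
        cot_sentences.append(sentence)
        query = sentence  # the next retrieval uses the newest CoT sentence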

kerkathy commented 11 months ago

Got it. Thank you for the awesome explanation; it's very generous of you to spend your time on this. Thanks again šŸ™‡ā€ā™‚ļø

kerkathy commented 11 months ago

Sorry for another question. šŸ˜æ

Since the size of our paragraph list is capped at 15, does it mean that when we choose K=8 (retrieving 8 paragraphs in each step), we can take at most two retrieval steps, while when we choose K=2, we can take at most 8 steps? Also, the paragraph list won't be updated anymore once it's full, i.e., we don't throw away the older ones to take in newer ones, right?

Thanks again šŸ˜ŗ

HarshTrivedi commented 11 months ago

In our experiments: in each step we retrieve K paragraphs, and from those we append only the new ones to the context list. Anything that overflows beyond 15 in any step is not used, as context_list is set to context_list[:15] after appending (as in the sketch above). If K is large and the capacity fills up in 1-2 steps but the CoT hasn't finished, the context remains the same for the remaining steps. We only stop when the CoT finishes, not when the context limit is reached. Such a K may not work best and would perhaps be filtered out in hyperparameter tuning in favor of a smaller K.

Note that a capacity of 15 with K=8 doesn't necessarily mean the capacity will be filled in 2 steps, because subsequent retrievals may overlap with previous ones, meaning that not all of them will be added to the context_list.

I am sure one could be smarter about what to keep and/or remove, and by how much, in each step. But at least some of the more complex variations we tried did not result in any improvements and didn't justify the added complexity.

kerkathy commented 11 months ago

I see. Thank you so much!