StonyBrookNLP / ircot

Repository for Interleaving Retrieval with Chain-of-Thought Reasoning for Knowledge-Intensive Multi-Step Questions, ACL23
https://arxiv.org/abs/2212.10509
Apache License 2.0
173 stars, 21 forks

address.jsonnet file format and CUDA error #5

Closed: minjunp closed this issue 1 year ago.

minjunp commented 1 year ago

Hi,

I'm trying to reproduce the results, and I found that llm_server_address.jsonnet and retriever_address.jsonnet are necessary. Can you provide example files for these?

Also, I'm getting a torch.cuda.OutOfMemoryError: CUDA out of memory error. If you can give me some tips to avoid this CUDA error (e.g., where to reduce the batch size), that would be appreciated.

Thank you in advance :)

HarshTrivedi commented 1 year ago

Oops, it looks like those files didn't get added as they were in my .gitignore. I've added them now.
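
For readers wondering how these address files relate to the URLs the client calls, here is a minimal Python sketch. It assumes each address file holds a host (including the http:// scheme) and a port. Those field names and the helper build_endpoint are illustrative assumptions, not the repository's confirmed format. Port 8000 and the /retrieve route are the values that appear later in this thread.

```python
def build_endpoint(address: dict, path: str) -> str:
    """Join an assumed {"host": ..., "port": ...} address with a route path.

    Hypothetical helper: the "host"/"port" field names are assumptions,
    not taken from the repository.
    """
    host = address["host"].rstrip("/")  # e.g. "http://localhost"
    port = int(address["port"])
    return f"{host}:{port}/{path.lstrip('/')}"


# When the retriever runs on the same machine as the client, localhost works.
retriever_address = {"host": "http://localhost", "port": 8000}
print(build_endpoint(retriever_address, "retrieve"))  # http://localhost:8000/retrieve
```

If a server runs on a different machine than the client, host would need that machine's reachable address instead of localhost.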

Regarding OOM: there is no batching anywhere in inference; examples are processed one at a time. However, there are two things you can do to reduce memory usage.

  1. Instead of the default Flan-T5 models, use the BF16 versions. Say you want to use flan-t5-xxl-bf16 instead of flan-t5-xxl: you'll need to replace occurrences of the latter with the former in the experiment config of your choice. E.g., for the IRCoT QA Flan-T5-XXL MuSiQue run, you'll make that change in this file. Look at the file/config names in this folder and it should be clear. We ran our Flan-T5-XXL experiments on an A100 (80 GB) and the rest on an A6000 (48 GB), all without BF16. If you use BF16, you can run all experiments with 48 GB. In the few experiments I tried, Hugging Face's BF16 versions gave the same performance, but I haven't done an exhaustive comparison.
  2. If you still can't fit it in your GPU memory, you can reduce the maximum number of tokens the model is allowed. To do this, change model_tokens_limit in the experiment config (e.g., this one) from 6000 to a lower number. This will have some impact on performance, but the drop may not be large, depending on how much you have to shorten the context.
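
A rough back-of-the-envelope estimate of why BF16 helps, counting model weights only. It assumes Flan-T5-XXL has about 11 billion parameters and that the non-BF16 checkpoint is held in FP32 (4 bytes per parameter), versus 2 bytes per parameter in BF16. Activations for long inputs (the part model_tokens_limit controls) come on top of this, so treat the numbers as lower bounds.

```python
# Bytes needed to store one parameter in each precision.
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


FLAN_T5_XXL_PARAMS = 11e9  # approximate parameter count (assumption)

for dtype in ("fp32", "bf16"):
    print(f"{dtype}: ~{weight_memory_gb(FLAN_T5_XXL_PARAMS, dtype):.0f} GB of weights")
# fp32: ~44 GB of weights
# bf16: ~22 GB of weights
```

With roughly 44 GB of FP32 weights, a 48 GB card has almost nothing left for activations, which is consistent with the A100 (80 GB) being used for Flan-T5-XXL without BF16; at roughly 22 GB in BF16 there is considerably more headroom.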

minjunp commented 1 year ago

Thank you for the quick and detailed explanation.

I'm new to FastAPI. For the jsonnet files you provided, can I use http://localhost when running on a server? I'm running the code on GCP, but I'm not sure whether I'm running FastAPI properly.

When I run the script below:

./reproduce.sh ircot flan-t5-xxl hotpotqa

I get a message as follows:

Token indices sequence length is longer than the specified maximum sequence length for this model (558 > 512). Running this sequence through the model will result in indexing errors
Running inference on examples
0it [00:00, ?it/s]Post request didn't succeed. Will wait 20s and retry.
Post request didn't succeed. Will wait 20s and retry.

and the message repeats.

HarshTrivedi commented 1 year ago

You can ignore "Token indices sequence length is longer than the specified maximum sequence length for this model (558 > 512)."; it's a warning coming from HF (Hugging Face).

The message "Post request didn't succeed. Will wait 20s and retry." means that your client cannot connect to the server, which can happen for many reasons. So try putting a breakpoint at that point and see what response = requests.post(url, json=params); print(response.text) gives you. Feel free to post it here if you need help.
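
A standalone way to run the same check outside the client, sketched in Python with only the standard library (urllib.request in place of requests, so it runs without extra packages). probe_post is a hypothetical helper, not part of the repository; pass it the same url and params the client sends at that breakpoint.

```python
import json
import urllib.error
import urllib.request


def probe_post(url: str, params: dict, timeout: float = 10.0) -> str:
    """POST params as JSON and describe what came back.

    Roughly equivalent to requests.post(url, json=params), stdlib only.
    """
    request = urllib.request.Request(
        url,
        data=json.dumps(params).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            return f"HTTP {response.status}: {response.read().decode('utf-8')}"
    except urllib.error.HTTPError as error:
        # The server answered, but with an error status (e.g. 404, 405, 422).
        return f"HTTP {error.code}: {error.read().decode('utf-8')}"
    except urllib.error.URLError as error:
        # No HTTP answer at all: wrong host/port, server down, firewall, ...
        return f"connection failed: {error.reason}"
    except OSError as error:
        # e.g. a timeout while waiting for the response body
        return f"connection failed: {error}"
```

A "connection failed" result points at networking (wrong host or port, server not running, blocked port), whereas an HTTP error status means the server is reachable but rejected the request, and the body usually says why.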

minjunp commented 1 year ago

Thank you for the answer,

The issue I'm having is with the retriever server. I'm able to access localhost:8000, which returns

{
  "message": "Hello! This is a retriever server."
}

However, when I run predict.py, it seems the code sends a POST request to localhost:8000/retrieve, which returns:

{
  "detail": "Method Not Allowed"
}

I'm running your predict.py and getting the same error message (Post request didn't succeed. Will wait 20s and retry.), which comes from ircot.py. I don't think I should be getting "Method Not Allowed."

HarshTrivedi commented 1 year ago

Can you confirm that you got the Method Not Allowed message by putting a breakpoint on this line and inspecting response.text, and not by visiting localhost:8000/retrieve in the browser? Method Not Allowed means there is no route for the requested combination of path (/retrieve) and method (GET, POST, etc.). Visiting the URL in a browser sends a GET request, whereas the server expects a POST request. If you already got it from response.text, let me know and I can dig into what's going on later.
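
To see the GET-versus-POST distinction in isolation, here is a self-contained Python sketch using only the standard library. The RetrieveOnlyPost handler is a toy stand-in that only mimics the behavior described in this thread (POST accepted on /retrieve; GET answered with 405 and {"detail": "Method Not Allowed"}). It is not the repository's FastAPI retriever server, and the request payload is made up.

```python
import json
import threading
import urllib.error
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Optional


class RetrieveOnlyPost(BaseHTTPRequestHandler):
    """Toy stand-in for a server whose /retrieve route accepts POST only."""

    def _send_json(self, status: int, payload: dict) -> None:
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_POST(self) -> None:
        if self.path != "/retrieve":
            self._send_json(404, {"detail": "Not Found"})
            return
        length = int(self.headers.get("Content-Length", 0))
        query = json.loads(self.rfile.read(length) or b"{}")
        self._send_json(200, {"received": query})

    def do_GET(self) -> None:
        if self.path == "/retrieve":
            # Route exists, but only for POST: answer like the server in this thread.
            self._send_json(405, {"detail": "Method Not Allowed"})
        else:
            self._send_json(404, {"detail": "Not Found"})

    def log_message(self, *args) -> None:  # keep the demo output quiet
        pass


def status_for(url: str, method: str, payload: Optional[dict] = None) -> int:
    """Send one request and return the HTTP status code."""
    data = None if payload is None else json.dumps(payload).encode("utf-8")
    request = urllib.request.Request(
        url, data=data, method=method, headers={"Content-Type": "application/json"}
    )
    try:
        with urllib.request.urlopen(request, timeout=5) as response:
            return response.status
    except urllib.error.HTTPError as error:
        return error.code


server = HTTPServer(("127.0.0.1", 0), RetrieveOnlyPost)  # port 0: any free port
threading.Thread(target=server.serve_forever, daemon=True).start()
url = f"http://127.0.0.1:{server.server_port}/retrieve"
print("GET  (like a browser visit):", status_for(url, "GET"))  # 405
print("POST (like the client):", status_for(url, "POST", {"query": "test"}))  # 200
server.shutdown()
server.server_close()
```

A browser visit produces the 405 while the POST succeeds, which is why it matters whether the message came from the browser or from response.text at the client's breakpoint.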

Also, can you confirm that your Elasticsearch server is running fine and that you've already run the indexing scripts? You can check by running curl localhost:9200/_cat/indices. It should list the different indices and their sizes, which should roughly match what's given in the README (a different exact size isn't necessarily a failure).
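
The curl check above can also be scripted. Below is a minimal standard-library Python sketch that requests the same _cat/indices endpoint (with the ?v flag, which adds a header row) and reports when nothing answers. The helper name list_es_indices is hypothetical, and the default base URL assumes Elasticsearch on localhost:9200 as in the curl command.

```python
import urllib.error
import urllib.request
from typing import Optional


def list_es_indices(base_url: str = "http://localhost:9200", timeout: float = 5.0) -> Optional[str]:
    """Return the text of GET {base_url}/_cat/indices?v, or None on failure."""
    try:
        with urllib.request.urlopen(f"{base_url}/_cat/indices?v", timeout=timeout) as response:
            return response.read().decode("utf-8")
    except urllib.error.HTTPError as error:
        # Something answered on that port, but not with a successful status.
        print(f"Elasticsearch answered with HTTP {error.code}")
        return None
    except OSError as error:  # URLError, refused connection, timeout, ...
        print(f"Elasticsearch not reachable at {base_url}: {error}")
        return None


indices = list_es_indices()
if indices is not None:
    print(indices)  # one row per index, including its document count and size
```

Running it on the same machine as the retriever server separates the two failure modes: if it returns None there too, Elasticsearch itself is not running or not listening on that port.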

minjunp commented 1 year ago

It seems I'm having an issue with Elasticsearch: I cannot access port 9200 on GCP. Let me resolve this and come back if the issue persists. Thank you again for your prompt responses.