dki-lab / Pangu

Code for reproducing the ACL'23 paper: Don't Generate, Discriminate: A Proposal for Grounding Language Models to Real-World Environments
https://arxiv.org/abs/2212.09736

Compute Resources. #6

Open PrayushiFaldu opened 1 year ago

PrayushiFaldu commented 1 year ago

I am training the T5 variant of the model (grail_train_t5.jsonnet). The config says

I am running on an A100 with 40 GB of memory. It shows a total training time of around 120 hours (around 70 seconds per step).

Is this the expected training time? Is there any way to optimize it? Can I increase the batch size, and would that affect performance?

Edit: The 120 hours is for one epoch, not for the whole run. Am I making a mistake?
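As a rough sanity check on these numbers, assuming the ~70 s step time stays constant (the maintainer later notes that it does not once SPARQL results are cached), 120 hours per epoch implies about 6,200 training steps. The function names are illustrative and the 10 s/step figure is hypothetical, used only to show how strongly the ETA depends on step time.

```python
def implied_steps_per_epoch(hours_per_epoch: float, seconds_per_step: float) -> float:
    """Number of steps implied by an ETA, assuming a constant time per step."""
    return hours_per_epoch * 3600 / seconds_per_step


def eta_hours(steps: float, seconds_per_step: float) -> float:
    """Wall-clock hours for a given number of steps at a constant step time."""
    return steps * seconds_per_step / 3600


if __name__ == "__main__":
    steps = implied_steps_per_epoch(120, 70)
    print(f"~{steps:.0f} steps per epoch")  # ~6171 steps per epoch
    # Hypothetical faster step time, e.g. after many queries are cached.
    print(f"{eta_hours(steps, 10):.1f} h at 10 s/step")  # 17.1 h at 10 s/step
```

Because the ETA is linear in step time, any speed-up from caching shows up proportionally in the estimate.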

AiRyunn commented 1 year ago

I'm having the same issue with training and inference time. When I was training with BERT on GrailQA, it took 30+ hrs for training and 20+ hrs for inference (for validation). Even using cached graph queries in the later epochs doesn't seem to save much time. The author told me that the total training time would be 3-5 days, but in my case it still took 1-2 weeks.

entslscheia commented 1 year ago

@AiRyunn I apologize for any inconvenience related to the code. Have you tried distributed training (see #7)? For experiments with GrailQA, I actually used 4 GPU cards to accelerate training. In addition, you can use a smaller patience number for early stopping when working with GrailQA. Most of my experiments on GrailQA converged at the 1st or 2nd epoch, which allowed me to complete training in less than 3 days.
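To illustrate why a smaller patience shortens training when the model converges in the first epoch or two, here is a generic patience-based early-stopping sketch. It is not Pangu's trainer code: `EarlyStopper`, `stop_epoch`, and the F1 values are hypothetical, and the exact patience key in the jsonnet config is not shown in this thread.

```python
class EarlyStopper:
    """Signal a stop once the validation metric fails to improve for `patience` epochs."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.best = float("-inf")
        self.epochs_without_improvement = 0

    def should_stop(self, metric: float) -> bool:
        # Higher is better here (e.g. validation F1).
        if metric > self.best:
            self.best = metric
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


def stop_epoch(scores, patience):
    """Return the epoch after which training would stop, or None if it never does."""
    stopper = EarlyStopper(patience)
    for epoch, score in enumerate(scores, start=1):
        if stopper.should_stop(score):
            return epoch
    return None


if __name__ == "__main__":
    # Hypothetical validation F1 per epoch, peaking at epoch 2.
    f1_per_epoch = [0.70, 0.74, 0.73, 0.72, 0.71, 0.70]
    print(stop_epoch(f1_per_epoch, patience=1))  # 3
    print(stop_epoch(f1_per_epoch, patience=3))  # 5
```

With patience 1 the run ends two epochs earlier than with patience 3, while the best checkpoint (epoch 2) is the same in both cases.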

Using a single card for GrailQA might result in a training time of over 30 hrs; however, 20 hrs of validation time during training seems abnormal. Based on my experiment logs for BERT-base on GrailQA, validation during training took only around 1 hr when using 4 cards (see the screenshot below).

[Screenshot: BERT-base on GrailQA training log, 4 GPUs]

I'd like to discuss this further to help you resolve the issue. Could you double-check whether the cache file is properly written to your disk? Meanwhile, I am also running some tests to see whether a bug in the uploaded version of my code could be causing the speed issue.
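Since several replies in this thread hinge on whether query results are actually cached on disk, here is a minimal sketch of a file-backed query cache, assuming results are JSON-serializable lists keyed by the query string. `QueryCache`, `demo`, and the JSON layout are hypothetical and are not Pangu's actual cache implementation; the point is only that a second run reading the same file should not hit the KG endpoint again.

```python
import json
import os
import tempfile
from typing import Callable, Dict, List, Tuple


class QueryCache:
    """Memoize KG query results in a JSON file on disk (hypothetical format)."""

    def __init__(self, path: str, run_query: Callable[[str], List[str]]) -> None:
        self.path = path
        self.run_query = run_query  # the slow call to the SPARQL endpoint
        self.hits = 0
        self.misses = 0
        self.store: Dict[str, List[str]] = {}
        if os.path.exists(path):  # reuse results written by earlier runs/epochs
            with open(path, encoding="utf-8") as f:
                self.store = json.load(f)

    def get(self, query: str) -> List[str]:
        if query in self.store:
            self.hits += 1
        else:
            self.misses += 1
            self.store[query] = self.run_query(query)
        return self.store[query]

    def save(self) -> None:
        # Write to a temporary file, then swap it in with os.replace, so an
        # interrupted run cannot leave a truncated cache file behind.
        tmp_path = self.path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(self.store, f)
        os.replace(tmp_path, self.path)


def demo() -> Tuple[int, int]:
    calls: List[str] = []

    def fake_endpoint(query: str) -> List[str]:
        calls.append(query)  # record every real trip to the "endpoint"
        return ["people.person.place_of_birth"]

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "cache.json")
        first = QueryCache(path, fake_endpoint)
        first.get("q1")
        first.get("q1")  # served from memory
        first.save()
        second = QueryCache(path, fake_endpoint)  # simulates the next epoch
        second.get("q1")  # served from the file on disk
        return len(calls), second.hits


if __name__ == "__main__":
    print(demo())  # (1, 1): the endpoint was queried only once
```

If the file is never saved, or is saved somewhere the next run does not read, every epoch pays the full query cost again, which matches the symptom described above.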

entslscheia commented 1 year ago

@PrayushiFaldu I believe the 120 hours might be an initial estimate from the early phase of training. As training progresses, more SPARQL queries get cached, which significantly speeds things up. In my own experiments using T5-base on GrailQA with 4 A6000 (48 GB) cards, the training time dropped to ~10 hours in the later epochs. If you're using a single card, it will be approximately 20 hours.

[Screenshot: T5-base on GrailQA training log]

entslscheia commented 1 year ago

Based on community feedback, it appears that training time is a prevalent concern. I'll thoroughly investigate this matter and provide clarity on all potential scenarios.

Appreciate your patience!

entslscheia commented 1 year ago

I've added a section in README to address concerns related to training time. Additionally, I've uploaded the cache file in that section. Hope it helps.

PrayushiFaldu commented 1 year ago

Thank you @entslscheia for the help. Do you have a similar estimate for inference? Each question seems to take about one minute, and since inference is not batched, 10k+ questions will take a very long time.

Please correct me if my estimates above are wrong.
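The estimate above works out as follows; the function name is illustrative, and 60 s is the per-question figure reported above rather than a measured constant.

```python
def inference_hours(num_questions: int, mean_seconds_per_question: float) -> float:
    """Wall-clock hours for unbatched inference that handles one question at a time."""
    return num_questions * mean_seconds_per_question / 3600


if __name__ == "__main__":
    print(f"{inference_hours(10_000, 60):.1f} h")  # 166.7 h, i.e. about 7 days
```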

AiRyunn commented 1 year ago

@ysu1989 Thank you for your clarification. I'm trying distributed training now, and this should take some time. I had the same experience as @PrayushiFaldu: at inference, some questions take 1-3 mins while others take only a few secs, so it often gets stuck on some questions. This makes the overall inference time very long. I guess this is because the code performs a full graph query at inference instead of sampling node relations as it does during training, which makes the results difficult to cache.

entslscheia commented 1 year ago

> Thank you @entslscheia for the help. Do you have any similar estimation for inference? It seems that one question is taking one minute and since there is no batching done during inference 10k+ questions will take very long time.
>
> Please correct if I am making any mistakes in above estimates.

Hi,

Have you tried running prediction with the cache file I just uploaded? Also, Q2 in the README may partly answer your question.

entslscheia commented 1 year ago

> @ysu1989 Thank you for your clarification. I'm trying distributed training now and this should take some time. I had the same experience with @PrayushiFaldu. At inference, some questions may take 1-3 mins, while others take only a few secs, so it often gets stuck on some questions. This makes the overall inference time very long. I guess this is because the code performs a full graph query at inference rather than sampling relations of nodes at training, which makes it difficult to cache.

I think it may not be very accurate to say that we "sample" relations during training.

To get a better sense of the difference between training and inference, consider an example. Say the current plan P is (JOIN relation_b (JOIN relation_a entity0)). During training, we use the execution of P (i.e., a set of entities) to query the KG and find the set of viable relations to expand. Specifically, for each entity in P, we query the KG for the relations reachable from it, and then take the union of these relations over all entities in P. However, P can sometimes be very large (e.g., thousands of entities), so during training we only use the first 100 entities in P to find viable relations. This is in a sense an approximate algorithm, and in this example training only executes one-hop queries over the KG (i.e., querying the relations of a single entity), which are much faster than multi-hop queries.
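A toy sketch of the training-time approximation just described, assuming the KG is a plain list of (subject, relation, object) triples and that only outgoing relations count (the real system may also consider reverse relations). `outgoing_relations` and `viable_relations` are hypothetical helpers; the cap of 100 follows the text, and `max_entities=None` gives the uncapped union over the whole execution.

```python
from typing import Iterable, List, Optional, Set, Tuple

Triple = Tuple[str, str, str]  # (subject, relation, object)


def outgoing_relations(kg: Iterable[Triple], entity: str) -> Set[str]:
    """One-hop lookup: relations for which `entity` is the subject."""
    return {rel for subj, rel, _ in kg if subj == entity}


def viable_relations(kg: List[Triple], execution: List[str],
                     max_entities: Optional[int] = 100) -> Set[str]:
    """Union of one-hop relations over (a prefix of) the plan's execution.

    max_entities=100 mirrors the training-time cap; None uses every entity.
    """
    entities = execution if max_entities is None else execution[:max_entities]
    relations: Set[str] = set()
    for entity in entities:
        relations |= outgoing_relations(kg, entity)
    return relations


if __name__ == "__main__":
    # Toy KG: 150 entities share one relation; only e120 has a second one.
    kg = [(f"e{i}", "common.topic.name", f"name{i}") for i in range(150)]
    kg.append(("e120", "location.location.contains", "e0"))
    execution = [f"e{i}" for i in range(150)]
    print(sorted(viable_relations(kg, execution)))        # capped at 100 entities
    print(sorted(viable_relations(kg, execution, None)))  # uncapped union
```

In this toy case the capped lookup misses `location.location.contains` because the only entity carrying it sits past position 100, which is the kind of small discrepancy the approximation can introduce.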

On the other hand, during inference you could essentially replicate the training process. This approximation doesn't significantly affect the final metrics; empirically, there's only about a 1-2% deviation in F1 score. However, to maximize Pangu's performance and showcase the best possible results, we opted for an exact query method to identify relations during inference (also explained in Q2 of the README). Basically, for P, instead of using its execution, we directly query the KG for relations reachable from the subgraph derived from (JOIN relation_b (JOIN relation_a entity0)), which is a multi-hop query over the KG. However, this is not related to caching: multi-hop queries are no harder to cache than the one-hop queries issued during training, so as long as the queries are cached, you can achieve the same level of acceleration.
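A sketch of what such a single multi-hop query could look like, assuming a Freebase-style SPARQL endpoint with the `ns:` prefix, assuming `(JOIN r X)` selects the subjects `x` of triples `(x, r, y)` with `y` in `X`, and looking only at outgoing relations of the last variable. Pangu's actual query construction is not shown in this thread; `chain_join_sparql` is a hypothetical helper.

```python
from typing import List

FREEBASE_PREFIX = "PREFIX ns: <http://rdf.freebase.com/ns/>"


def chain_join_sparql(entity: str, relations: List[str]) -> str:
    """Build one SPARQL query for the relations reachable from a chain of JOINs.

    `relations` is listed innermost first: ["relation_a", "relation_b"] encodes
    (JOIN relation_b (JOIN relation_a entity0)).
    """
    patterns = []
    current = f"ns:{entity}"
    for i, relation in enumerate(relations):
        var = f"?x{i}"
        patterns.append(f"  {var} ns:{relation} {current} .")
        current = var
    patterns.append(f"  {current} ?r ?o .")  # the hop whose relations we want
    body = "\n".join(patterns)
    return f"{FREEBASE_PREFIX}\nSELECT DISTINCT ?r WHERE {{\n{body}\n}}"


if __name__ == "__main__":
    print(chain_join_sparql("entity0", ["relation_a", "relation_b"]))
```

The whole plan becomes one query string, so it can be used as a cache key exactly like a one-hop query.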

Also, the cache file I uploaded should already include some of these multi-hop queries triggered on GrailQA, but they may not be complete.

PrayushiFaldu commented 1 year ago

> @ysu1989 Thank you for your clarification. I'm trying distributed training now and this should take some time. I had the same experience with @PrayushiFaldu. At inference, some questions may take 1-3 mins, while others take only a few secs, so it often gets stuck on some questions. This makes the overall inference time very long. I guess this is because the code performs a full graph query at inference rather than sampling relations of nodes at training, which makes it difficult to cache.

@AiRyunn I hope you are using a GPU for inference. I was missing the extra cuda-device argument earlier; with it, the code is much faster than before.