qqaatw / pytorch-realm-orqa

PyTorch reimplementation of REALM and ORQA
Apache License 2.0

run_finetune takes too much time. #8

Open Boltzmachine opened 2 years ago

Boltzmachine commented 2 years ago

Under the default settings, one epoch costs 700+ hours on one Tesla V100, which makes training impossible. I found that the retriever selects the top 5000 docs and tokenizes them, which takes up most of the time. How can I fix this? Can I make beam_size smaller?

qqaatw commented 2 years ago

For now you'll need to install transformers from source: the patch that makes the fast tokenizer the default, which greatly improves retrieval speed, hasn't been included in the latest transformers release yet.

Boltzmachine commented 2 years ago

I installed from source, but it still needs 50+ hours.

qqaatw commented 2 years ago

Yes, then you can try reducing the beam size. This would probably give results not as good as the paper's.
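
If it helps, here is a rough sketch of how the retrieval beam could be shrunk, assuming the repo uses the REALM classes from transformers; the exact option exposed by run_finetune may be named differently, and 500 is just a hypothetical value:

    from transformers import RealmConfig, RealmForOpenQA, RealmRetriever

    # Sketch only: shrink the retriever's beam from the paper default (5000).
    # In transformers' RealmConfig, `searcher_beam_size` controls how many
    # candidate documents are retrieved per question, and `reader_beam_size`
    # controls how many of those the reader actually scores.
    config = RealmConfig.from_pretrained("google/realm-orqa-nq-openqa")
    config.searcher_beam_size = 500  # hypothetical smaller beam

    retriever = RealmRetriever.from_pretrained("google/realm-orqa-nq-openqa")
    model = RealmForOpenQA.from_pretrained(
        "google/realm-orqa-nq-openqa", config=config, retriever=retriever
    )

As said above, anything smaller than the paper's 5000 will likely trade some accuracy for speed.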

Boltzmachine commented 2 years ago

I found that in the paper the number of candidate documents is set to 8:

For each example, we retrieve and marginalize over 8 candidate documents, including the null document

Why do you set it to 5000 in the code?

qqaatw commented 2 years ago

Fine-tuning and pre-training of REALM are totally different things. This repository only covers the fine-tuning part, whose details are described in the ORQA paper (see the README).

Also, all the hyperparameters this repository uses, e.g. the beam size, follow the paper's default configuration so that its results can be fully reproduced.

And just so you know, if you are curious about why the pre-training part wasn't migrated here, check out the other issues.

Boltzmachine commented 2 years ago

Thanks. Does retrieving 5k documents also run slowly on your machine (as far as I know the original implementation is in TensorFlow)? If it does, do you mind if I optimize the code and propose a PR?

qqaatw commented 2 years ago

Yes, fine-tuning on Natural Questions with the default configuration also took around 2 days on a single 2080 Ti GPU, which is acceptable for me.

The bottleneck is that the original TF implementation leverages custom C++ ops to perform all the heavy retrieval work, while our PyTorch implementation is pure Python except for the fast tokenizer, which has a Rust backend.

Therefore, unless we implement an efficient backend (like their C++ ops) for the retrieval, I believe the training speed won't improve much. Anyway, PRs are always welcome.

Boltzmachine commented 2 years ago

Is it possible to pre-tokenize all the documents and just concatenate them in the forward function?

qqaatw commented 2 years ago
    import time

    from transformers import RealmTokenizerFast

    tokenizer = RealmTokenizerFast.from_pretrained("google/realm-orqa-nq-openqa")

    # Simulate one beam: the same question paired with 5000 retrieved documents.
    text = [
        "What is the previous name of Meta Platform, Inc.?"
        for i in range(5000)
    ]
    text_pair = [
        "Meta Platforms, Inc., doing business as Meta and formerly known as Facebook, Inc., is an American multinational technology conglomerate based in Menlo Park, California. The company is the parent organization of Facebook, Instagram, and WhatsApp, among other subsidiaries. Meta is one of the world's most valuable companies. It is one of the Big Five American information technology companies, alongside Google (Alphabet Inc.), Amazon, Apple, and Microsoft"
        for i in range(5000)
    ]

    # Time only the tokenization of the whole beam.
    start = time.time()
    ids = tokenizer(text, text_pair, padding=True, truncation=True, return_special_tokens_mask=True, max_length=512)

    print("Elapsed time:", time.time() - start)

    # Elapsed time: 0.29874181747436523

Very clearly, the bottleneck isn't the tokenization. Instead, finding the positions of the answer ids in the retrieved documents takes most of the time, and its complexity is worse than O(n^2).

That's why they wrote the custom ops.
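
For reference, a minimal sketch of what that matching amounts to (the names are made up, not the real code in the retriever): every retrieved document is scanned, at every start position, for every tokenized answer alias, which is what dominates the step time.

    # Illustrative only: brute-force search for answer token spans inside
    # each retrieved document's token ids. With thousands of candidate
    # documents, several answer aliases, and up to 512 tokens per document,
    # this nested scan is very expensive in pure Python.
    def find_answer_spans(doc_ids_batch, answer_ids_list):
        spans = []
        for doc_ids in doc_ids_batch:            # e.g. 5000 retrieved docs
            doc_spans = []
            for answer_ids in answer_ids_list:   # each tokenized answer alias
                n = len(answer_ids)
                for start in range(len(doc_ids) - n + 1):
                    if doc_ids[start:start + n] == answer_ids:
                        doc_spans.append((start, start + n - 1))
            spans.append(doc_spans)
        return spans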

Boltzmachine commented 2 years ago

On my computer, the tokenization process is also a bottleneck. I ran line_profiler (screenshot attached).

Anyway, I will also check the answer-finding process.

qqaatw commented 2 years ago

OK. If we accumulate the tokenization time on the Natural Questions dataset, 79K steps × 2 epochs × 0.3 s / 3600 ≈ 13.16 hours, so it does take appreciable time. However, just concatenating question and context doesn't work, because padding and truncation still need to be performed to keep the entire beam aligned.

Anyway, if pre-tokenization + padding/truncation can provide a significant improvement, that would be great.
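
In case it's useful as a starting point, here is a rough sketch of that idea, assuming transformers' RealmTokenizerFast (the helper names and corpus are made up): documents are tokenized once up front, and at training time only the question ids are prepended before truncating and padding the beam with tokenizer.pad. Whether this actually beats re-tokenizing depends on how much of the 0.3 s is the Rust tokenizer itself versus the Python-side assembly.

    from transformers import RealmTokenizerFast

    tokenizer = RealmTokenizerFast.from_pretrained("google/realm-orqa-nq-openqa")

    # Done once, offline: tokenize every document without padding/truncation.
    corpus = ["Meta Platforms, Inc. is an American technology conglomerate ..."] * 5000
    doc_ids = [ids[1:] for ids in tokenizer(corpus, add_special_tokens=True)["input_ids"]]
    # ids[1:] drops the leading [CLS], leaving "document tokens + [SEP]".

    # Done per training step: prepend the question ids to each retrieved
    # document, truncate to max_length, and let the tokenizer pad the beam.
    # (A real version would also have to build token_type_ids for the
    # question/document split and the special tokens mask.)
    def build_beam(question, retrieved_doc_ids, max_length=512):
        q_ids = tokenizer(question, add_special_tokens=True)["input_ids"]  # [CLS] q [SEP]
        features = [{"input_ids": (q_ids + d)[:max_length]} for d in retrieved_doc_ids]
        return tokenizer.pad(features, padding=True, return_tensors="pt")

    batch = build_beam("What is the previous name of Meta Platforms, Inc.?", doc_ids[:8])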