Closed M-J-Murray closed 1 year ago
When I run a profiler, it seems there is a bottleneck in the apply_rotary_pos_emb function called from the ESM sequence embedding. This bottleneck may be so large that changing the batch size has almost no effect.
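Roughly how I gathered the profile, in case it helps (a sketch, not my exact script; the input sequence here is just a placeholder):

import torch
from torch.profiler import profile, ProfilerActivity
from transformers import EsmForMaskedLM, EsmTokenizer

device = torch.device("cuda")
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D").eval().to(device)

# Any protein sequence works here; this one is just a placeholder
inputs = tokenizer("ARNDCQEGHILKMFPSTWYV" * 12, return_tensors="pt").to(device)

# Profile one forward pass and list ops by CUDA time; the rotary-embedding
# cost shows up in the tensor ops called from apply_rotary_pos_emb
with torch.no_grad(), profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model.base_model(**inputs)

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))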
cc @Rocketknight1
Hi
It would be good to provide a full code snippet. Currently it is not clear whether you are also including the time spent on tokenization (I guess that's not the case, however). In any case, a code snippet makes it easier for us to help.
Thank you in advance.
The following code was run on a Tesla V100-SXM2-16GB.
With a batch size of 10, it executes in 14.96 seconds.
With a batch size of 50, it executes in 13.6 seconds.
I would expect a much larger change in execution time between the two batch sizes. I would have thought a batch size of 50 would execute five times faster than a batch size of 10.
import numpy as np
import torch
from torch.utils.data import DataLoader
from transformers import EsmForMaskedLM, EsmTokenizer
import time

device = torch.device("cuda")
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D")
model.eval()
model = model.to(device)

batch_size = 10
samples = 500
sequence_length = 250

# Random protein sequences built from the 20 standard amino acids
tokens = list("ARNDCQEGHILKMFPSTWYV")
sequences = ["".join(np.random.choice(tokens, sequence_length)) for _ in range(samples)]

t0 = time.time()
with torch.no_grad():
    for batch_seqs in DataLoader(sequences, batch_size=batch_size):
        # Tokenization happens inside the timed loop, once per batch
        inputs = tokenizer(batch_seqs, return_tensors="pt")
        inputs = inputs.to(device)
        model.base_model(**inputs)
print(f"Execution time: {time.time() - t0}")
Thanks for the code snippet 🤗
Hi @M-J-Murray, I think there are a few confounding things here. Firstly, the ESM tokenizer is relatively unoptimized. This means tokenization takes longer than it does for other models. If performance is critical, I would strongly recommend tokenizing your sequences once, then saving the tokenized outputs, rather than tokenizing them on-the-fly in each loop. This applies to all models, but especially to ESM!
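To make that concrete, a minimal sketch of the pre-tokenization pattern (reusing tokenizer, sequences, model, device and batch_size from your snippet; padding=True is just a safe default and has no effect here since all your sequences are the same length):

from torch.utils.data import DataLoader, TensorDataset

# Tokenize the whole dataset once, up front, instead of once per batch
encoded = tokenizer(sequences, return_tensors="pt", padding=True)
dataset = TensorDataset(encoded["input_ids"], encoded["attention_mask"])

with torch.no_grad():
    for input_ids, attention_mask in DataLoader(dataset, batch_size=batch_size):
        model.base_model(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
        )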
Secondly, performance does not scale linearly with batch size. The same amount of computation has to be done for 100 batches of 2, or 2 batches of 100. The main reason to use larger batch sizes is that larger batches generally allow GPUs to do more work in parallel, which is helpful when the model is small, as small batch sizes in small models generally cannot use all the power of a high-end GPU at once. There is also the further benefit during training that fewer optimizer update steps are needed, but this does not apply when you're doing inference.
In this case, though, the model has 650M parameters, which is reasonably large. I would guess that even smaller batch sizes are enough to saturate a V100 GPU for a model of this size, so the performance benefit of larger batches would not be that significant. I think this, combined with the additional constant time added to your measurements from running the tokenizer in the loop, is enough to explain the lack of benefit, and the model is actually working as expected!
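If you want to verify that, one option (just a sketch, reusing encoded, model, device, TensorDataset and DataLoader from the snippet above; I haven't timed this exact code) is to time only the forward passes, with tokenization out of the loop and the GPU synchronized before reading the clock:

import time
import torch

def time_forward_passes(batch_size):
    dataset = TensorDataset(encoded["input_ids"], encoded["attention_mask"])
    loader = DataLoader(dataset, batch_size=batch_size)
    torch.cuda.synchronize()
    t0 = time.time()
    with torch.no_grad():
        for input_ids, attention_mask in loader:
            model.base_model(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
            )
    torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    return time.time() - t0

for bs in (2, 10, 50):
    print(bs, time_forward_passes(bs))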
@Rocketknight1 Thank you, I've just validated this and it does seem the tokenizer is the main bottleneck here. I will write my own tokenizer for now.
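In case it helps anyone else, this is roughly what I mean (a quick sketch built on the existing tokenizer's vocab; I haven't checked it against EsmTokenizer's exact special-token handling):

# Build a char -> id lookup from the existing tokenizer's vocab
vocab = tokenizer.get_vocab()

def fast_encode(seqs):
    # Assumes fixed-length sequences of standard residues only;
    # prepends <cls> and appends <eos>, which is what I believe the ESM tokenizer does
    ids = [
        [tokenizer.cls_token_id] + [vocab[c] for c in seq] + [tokenizer.eos_token_id]
        for seq in seqs
    ]
    input_ids = torch.tensor(ids)
    return {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids)}

batch = fast_encode(sequences[:10])
out = model.base_model(
    input_ids=batch["input_ids"].to(device),
    attention_mask=batch["attention_mask"].to(device),
)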
System Info
Transformers version: 4.30.2
Python version: 3.9.16
This occurs on both:
MacBook Pro M2, macOS 13.2.1 (22D68), run using mps
Debian 4.19.171-2 x86_64 GNU/Linux, run using gpu
Who can help?
@ArthurZucker and @younesbelkada
Reproduction
I'm using the EsmForMaskedLM model with facebook/esm2_t33_650M_UR50D along with the EsmTokenizer. If I run inference on 200 sequences, it takes the same amount of time to run 10 forward passes with a batch size of 20 as it does to run 100 forward passes with a batch size of 2. This seems to indicate the model doesn't support batch processing under the hood. It seems strange that the interface would imply that it supports batch processing without actually supporting it properly.
Expected behavior
I would expect running 10 forward passes to be much faster than 100 forward passes.