vblagoje opened this issue 3 years ago
That’s interesting, thanks. Let’s see what we can do. Can you detail your last sentence? I’m not sure I understand it well.
Hi! I just re-ran a quick benchmark, and using to_numpy() seems to be faster now:
import pyarrow as pa # I used pyarrow 3.0.0
import numpy as np
n, max_length = 1_000, 512
low, high, size = 0, 2 << 16, (n, max_length)
table = pa.Table.from_pydict({
"input_ids": np.random.default_rng(42).integers(low=low, high=high, size=size).tolist()
})
%%timeit
_ = table.to_pandas()["input_ids"].to_numpy()
# 1.44 ms ± 80.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%%timeit
_ = table["input_ids"].to_pandas().to_numpy()
# 461 µs ± 14.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%%timeit
_ = table["input_ids"].to_numpy()
# 317 µs ± 5.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Currently the conversion from arrow to numpy is done in the NumpyArrowExtractor here:
Let's update the NumpyArrowExtractor to call to_numpy directly and see how our GitHub benchmarks evolve?
Sounds like a plan @lhoestq If you create a PR I'll pick it up and try it out right away!
@lhoestq I can also prepare the PR, just lmk.
I’m not exactly sure how to read the graph, but it seems that to_categorical takes a lot of time here. Could you share more information on the features/stats of your datasets so we could maybe design a synthetic dataset that looks more similar, for debugging and testing?
I created https://github.com/huggingface/datasets/pull/2505 if you want to play with it @vblagoje
> I’m not exactly sure how to read the graph, but it seems that to_categorical takes a lot of time here. Could you share more information on the features/stats of your datasets so we could maybe design a synthetic dataset that looks more similar, for debugging and testing?
@thomwolf starting from the top, each rectangle represents the cumulative amount of time it takes to execute the method call. Therefore, format_batch in torch_formatter.py takes ~20 sec, and the largest portion of that call is taken by the to_pandas call, with the smaller portion (grey rectangle) taken by the other method invocation(s) in format_batch (series_to_numpy, etc.).
Features of the dataset are the BERT pre-training model input columns, i.e.:
f = Features({
"input_ids": Sequence(feature=Value(dtype="int32")),
"attention_mask": Sequence(feature=Value(dtype="int8")),
"token_type_ids": Sequence(feature=Value(dtype="int8")),
"labels": Sequence(feature=Value(dtype="int32")),
"next_sentence_label": Value(dtype="int8")
})
I'll work with @lhoestq till we get to the bottom of this one.
@lhoestq the proposed branch is faster, but the overall training speedup is only a few percentage points. I couldn't figure out how to include the GitHub branch in setup.py, so I couldn't start the NVIDIA-optimized Docker-based pre-training run. But on bare metal, there is a slight improvement. I'll do some more performance traces.
Hi @vblagoje, to install Datasets from @lhoestq PR reference #2505, you can use:
pip install git+ssh://git@github.com/huggingface/datasets.git@refs/pull/2505/head#egg=datasets
Hey @albertvillanova, yes, thank you, I am aware. I can easily pull it from the command line, but then I can't automate Docker image builds, as dependencies are picked up from setup.py, and for some reason setup.py doesn't accept this string format.
@vblagoje in that case, you can add this to your setup.py:
install_requires=[
    "datasets @ git+ssh://git@github.com/huggingface/datasets.git@refs/pull/2505/head",
]
@lhoestq @thomwolf @albertvillanova The new approach is definitely faster; the dataloader now takes less than 3% of cumulative time (the pink rectangle, two rectangles to the right of the tensor.py backward invocation).
When we drill down into dataloader next invocation we get:
And finally format_batch:
Not sure this could be further improved but this is definitely a decent step forward.
datasets @ git+ssh://git@github.com/huggingface/datasets.git@refs/pull/2505/head
@albertvillanova how would I replace datasets dependency in https://github.com/huggingface/transformers/blob/master/setup.py as the above approach is not working.
@vblagoje I tested my proposed approach before posting it here and it worked for me.
Is it not working in your case because of the SSH protocol? In that case you could try the same approach but using HTTPS:
"datasets @ git+https://github.com/huggingface/datasets.git@refs/pull/2505/head",
Also note the blanks before and after the @.
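Putting the pieces together, a minimal setup.py using such a PEP 508 direct reference might look like this (the package name and version are placeholders):

```python
from setuptools import setup

setup(
    name="my-pretraining-project",  # placeholder name
    version="0.0.1",
    install_requires=[
        # direct reference to the PR branch; note the blanks around the @
        "datasets @ git+https://github.com/huggingface/datasets.git@refs/pull/2505/head",
    ],
)
```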
@albertvillanova of course it works. Apologies. I needed to change the datasets dependency in all deps references, like here, for example.
Is time spent casting an issue here? See https://github.com/huggingface/datasets/issues/4676, which shows that Datasets can spend huge amounts of time repeatedly casting to Python objects.
Is your feature request related to a problem? Please describe.
It would be great, if possible, to further improve the read performance of raw encoded datasets and their subsequent conversion to torch tensors.
A bit more background. I am working on LM pre-training using HF ecosystem. We use encoded HF Wikipedia and BookCorpus datasets. The training machines are similar to DGX-1 workstations. We use HF trainer torch.distributed training approach on a single machine with 8 GPUs.
The current performance is about 30% slower than the NVIDIA-optimized BERT examples baseline. Quite a bit of customized code and training-loop tricks were used to achieve that baseline performance. It would be great to achieve the same performance while using nothing more than the off-the-shelf HF ecosystem. Perhaps, in the future, with @stas00's work on DeepSpeed integration, it could even be exceeded.
Describe the solution you'd like
Using profiling tools, we've observed that approx. 25% of cumulative run time is spent on the data loader next call.
As you can observe, most of the data loader next call is spent in the HF datasets torch_formatter.py format_batch call.
Digging a bit deeper into format_batch we can see the following profiler data:
Once again, a lot of time is spent in the pyarrow table conversion to pandas, which seems like an intermediary step. Offline, @lhoestq told me that this approach was, for some unknown reason, faster than direct to_numpy conversion.
Describe alternatives you've considered
I am not familiar with pyarrow and have not yet considered alternatives to the current approach.
Most of the online advice around data loader performance improvements revolves around increasing the number of workers and using pinned memory for copying tensors from the host to GPUs, but we've already tried these avenues without much performance improvement. The Weights & Biases dashboard for the pre-training task reports CPU utilization of ~10%, GPUs are completely saturated (GPU utilization is above 95% on all GPUs), while disk utilization is above 90%.
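For reference, the two knobs mentioned above look like this in a plain PyTorch setup (a toy sketch with illustrative values; as noted, they did not help much in our case):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# toy stand-in for the encoded pre-training dataset
dataset = TensorDataset(torch.randint(0, 2 << 16, (64, 512), dtype=torch.int32))

loader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=2,    # parallelize host-side batch preparation across processes
    pin_memory=True,  # page-locked buffers speed up host-to-GPU copies
)

batch, = next(iter(loader))
```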