lancedb / lance

Modern columnar data format for ML and LLMs implemented in Rust. Convert from Parquet in 2 lines of code for 100x faster random access, vector index, and data versioning. Compatible with Pandas, DuckDB, Polars, Pyarrow, with more integrations coming.
https://lancedb.github.io/lance/
Apache License 2.0

lance.torch.data.LanceDataset + torch.utils.data.DataLoader is slower than to_batches + frombuffer #2804

Open jacketsj opened 3 weeks ago

jacketsj commented 3 weeks ago

I was writing some GPU code that iterates through an entire dataset and found that the torch integration was quite slow. I wrote a simple to_batches-based implementation of what I was working on and found it to be much faster.
It is quite likely there are some cases my implementation does not handle, but at least in this case there is significant performance to be gained.

Here's some code that reproduces the issue (tested on lance 0.16.1, due to #2803):

import lance
import pyarrow as pa
import pyarrow.compute as pc
import time
import lance.torch.data as ltd
import torch
from torch.utils.data import DataLoader

dims = 128
nrows = 10_000_000

def next_batch(batch_size, offset):
    values = pc.random(dims * batch_size).cast('float32')
    return pa.table({
        'id': pa.array([offset + j for j in range(batch_size)]),
        'vector': pa.FixedSizeListArray.from_arrays(values, dims),
    }).to_batches()[0]

def batch_iter(num_rows):
    i = 0
    while i < num_rows:
        batch_size = min(10_000, num_rows - i)
        yield next_batch(batch_size, i)
        i += batch_size

schema = next_batch(1, 0).schema

ds_path = "./temp-test-torch.lance"

ds = lance.write_dataset(batch_iter(nrows), ds_path, schema=schema, mode="overwrite")

# Torch integration
tdataset = ltd.LanceDataset(
    ds,
    columns=["vector"],
    batch_size=1024,
    batch_readahead=8,
    with_row_id=True,
)

start_time = time.time()

tdataloader = DataLoader(tdataset)
for tbatch in tdataloader:
    tvecs = tbatch["vector"]
    trow_ids = tbatch["_rowid"]

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Torch integration elapsed time: {elapsed_time:.6f} seconds")

# Manual to_batches and frombuffer
start_time = time.time()

for batch in ds.to_batches(
        columns=["vector"],
        batch_size=1024,
        batch_readahead=8, 
        with_row_id=True,
):
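    batch_vecs = batch["vector"]
    # buffers() of a FixedSizeListArray is [list validity, child validity, child values];
    # wrap the float32 values buffer as a tensor without copying it.
    vecs = torch.frombuffer(batch_vecs.buffers()[2], dtype=torch.float32)
    vecs = vecs.view(len(batch_vecs), dims)
    batch_row_ids = batch["_rowid"]
    # buffers() of a primitive array is [validity, values]; wrap the row-id values buffer.
    row_ids = torch.frombuffer(batch_row_ids.buffers()[1], dtype=torch.int64)
    row_ids = row_ids.view(len(batch_row_ids), 1).squeeze()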

end_time = time.time()
elapsed_time = end_time - start_time
print(f"Manual ver elapsed time: {elapsed_time:.6f} seconds")

# Verify that this actually produces the same data
max_diff = 0

tdataloader = DataLoader(tdataset)
for tbatch, batch in zip(tdataloader, ds.to_batches(
        columns=["vector"],
        batch_size=1024,
        batch_readahead=8,
        with_row_id=True,
)):
    tvecs = tbatch["vector"]
    trow_ids = tbatch["_rowid"]

    batch_vecs = batch["vector"]
    vecs = torch.frombuffer(batch_vecs.buffers()[2], dtype=torch.float32)
    vecs = vecs.view(len(batch_vecs), dims)
    batch_row_ids = batch["_rowid"]
    row_ids = torch.frombuffer(batch_row_ids.buffers()[1], dtype=torch.int64)
    row_ids = row_ids.view(len(batch_row_ids), 1).squeeze()

    max_diff = max(torch.max(torch.abs(tvecs-vecs)).item(), max_diff)
    max_diff = max(torch.max(torch.abs(trow_ids-row_ids)).item(), max_diff)

print(f"Should be 0: {max_diff}")

Output (warnings removed):

Torch integration elapsed time: 6.260330 seconds
Manual ver elapsed time: 0.332647 seconds
Should be 0: 0

tonyf commented 3 weeks ago

This difference is likely because your test reads the dataset in order without any sharding, so you get the added benefit of batch-level and fragment-level read-ahead.

By default, LanceDataset uses ShardedFragmentSampler, which has no fragment-level read-ahead. If you set shard_granularity="batch", you'll probably get the same performance in your test.

However, in a real-world setting with shuffle=True, DataLoader num_workers > 1, and more than one torch distributed worker, the fragment sampler will be faster.
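To make the fragment vs. batch granularity distinction concrete, here is a simplified, pure-Python model of how batches might be assigned to ranks under each option. The function names and the round-robin assignment are hypothetical illustrations, not Lance's actual ShardedFragmentSampler or batch-sharding logic, and the sketch does not model read-ahead at all; it only shows which (fragment, batch) pairs each rank would read.

```python
from typing import List, Tuple

# (fragment_id, batch_index_within_fragment)
Batch = Tuple[int, int]


def shard_by_fragment(
    batches_per_fragment: List[int], rank: int, world_size: int
) -> List[Batch]:
    """Fragment granularity (simplified model): ranks take whole fragments, round-robin."""
    return [
        (frag, b)
        for frag, n in enumerate(batches_per_fragment)
        if frag % world_size == rank
        for b in range(n)
    ]


def shard_by_batch(
    batches_per_fragment: List[int], rank: int, world_size: int
) -> List[Batch]:
    """Batch granularity (simplified model): ranks take every world_size-th batch overall."""
    all_batches = [
        (frag, b) for frag, n in enumerate(batches_per_fragment) for b in range(n)
    ]
    return all_batches[rank::world_size]
```

For example, with fragments holding [3, 1, 4, 2] batches and two ranks, fragment granularity gives rank 0 seven batches (fragments 0 and 2) and rank 1 three, while batch granularity gives each rank five batches spread across several fragments. In this model, fragment granularity keeps each rank's reads contiguous within a fragment at the cost of balance; with a single unsharded reader, both orderings are identical.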

tonyf commented 3 weeks ago

If you want to compare, I have an experimental fragment-level sharded sampler with read-ahead here: https://gist.github.com/tonyf/7087dd3130ee5df1e93b862d55230f1c
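The gist itself is not reproduced here. As a generic illustration of what "read-ahead" means, here is a minimal sketch of a bounded prefetching iterator: while the consumer works on item i, up to `readahead` later items are already being loaded on a thread pool, and results are still yielded in key order. `readahead_iter` and the `load` callable are hypothetical names; in practice `load` would stand for something like reading one fragment or batch. This is not Lance's internal read-ahead implementation.

```python
import itertools
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Deque, Iterable, Iterator, TypeVar

K = TypeVar("K")
V = TypeVar("V")

_DONE = object()  # sentinel marking an exhausted key iterator


def readahead_iter(
    keys: Iterable[K], load: Callable[[K], V], readahead: int = 8
) -> Iterator[V]:
    """Yield load(key) for each key in order, with up to `readahead` loads
    running ahead of the consumer on a thread pool (hypothetical helper)."""
    if readahead < 1:
        raise ValueError("readahead must be >= 1")
    key_iter = iter(keys)
    with ThreadPoolExecutor(max_workers=readahead) as pool:
        # Fill the window with the first `readahead` loads.
        pending: Deque[Future] = deque(
            pool.submit(load, k) for k in itertools.islice(key_iter, readahead)
        )
        while pending:
            # Wait on the oldest load so results come out in key order.
            result = pending.popleft().result()
            # Keep the window full: start the next load before yielding.
            nxt = next(key_iter, _DONE)
            if nxt is not _DONE:
                pending.append(pool.submit(load, nxt))
            yield result
```

The design choice here is a fixed-size window: memory stays bounded by `readahead` in-flight items, while I/O for upcoming items overlaps with whatever the consumer does with the current one.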

I also have a row-id-based sharded sampler, https://gist.github.com/tonyf/d512e26183d97eb4fbae9c0b6abe5072, which is more strictly correct in sharded scenarios, though not as fast.
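The linked gist is not reproduced here. One plausible reading of "more strictly correct" is that every rank sees a disjoint set of rows and exactly the same number of them, which matters when ranks must stay in lockstep. The sketch below illustrates that general idea under stated assumptions: `shard_row_ids` is a hypothetical helper, and row positions 0..num_rows-1 stand in for real row ids. Each rank would then fetch its rows with random-access reads, which is presumably where the speed cost comes from.

```python
import random
from typing import List


def shard_row_ids(
    num_rows: int, rank: int, world_size: int, seed: int = 0, drop_last: bool = True
) -> List[int]:
    """Assign row positions to one rank (hypothetical sketch, not the gist's code).

    All ranks shuffle with the same seed, so they agree on one global
    permutation and then take disjoint strided slices of it.
    """
    ids = list(range(num_rows))
    random.Random(seed).shuffle(ids)
    if drop_last:
        # Trim the tail so every rank gets exactly the same number of rows.
        ids = ids[: num_rows - num_rows % world_size]
    return ids[rank::world_size]
```

With `drop_last=True`, 10 rows over 3 ranks gives each rank 3 rows and drops one; with `drop_last=False`, every row is assigned but the ranks get 4, 3, and 3 rows.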

More on all of this here: https://github.com/lancedb/lance/discussions/2781

jacketsj commented 3 weeks ago

Thanks for the idea, although I'm currently convinced the slowdown comes from this function, which copies data into numpy as an intermediary (at least in 0.16.1). That's similar to what I was doing with to_batches before switching to frombuffer, and that version performed nearly identically to the torch integration above. The function has changed significantly since 0.16.1, so I need to fix #2803 (likely caused by that same change). Neither is urgent at the moment, but both should hopefully be quick to fix once I get to them.
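To illustrate the copy vs. zero-copy distinction without depending on pyarrow or torch, here is a stdlib-only sketch: a flat float32 `array` stands in for an Arrow values buffer, a reshaped `memoryview` plays the role of `torch.frombuffer(...).view(...)` (which shares memory with the wrapped buffer), and the list-of-lists path stands in for an intermediate copy such as a numpy conversion. The helper names are hypothetical. Like the reproducer's frombuffer call, the zero-copy path assumes the values start at the beginning of the buffer and contain no nulls; a sliced or nullable Arrow array would need its offset and validity bitmap handled.

```python
from array import array
from typing import List


def as_matrix_view(values: array, dims: int) -> memoryview:
    """Zero-copy: reinterpret a flat float32 buffer as a (rows, dims) matrix."""
    if len(values) % dims:
        raise ValueError("buffer length is not a multiple of dims")
    rows = len(values) // dims
    # memoryview.cast only reshapes to/from a byte format, so hop through "B".
    return memoryview(values).cast("B").cast("f", shape=[rows, dims])


def as_matrix_copy(values: array, dims: int) -> List[List[float]]:
    """Copying: materialize every element into fresh Python lists."""
    return [values[i : i + dims].tolist() for i in range(0, len(values), dims)]
```

The view costs the same regardless of buffer size and reflects later writes to the source buffer, while the copy touches every element once and is independent of the source afterwards.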