jmpanfil opened this issue 2 years ago
There are many pieces at work here: `SparkDatasetConverter` stores the Spark dataframe as a parquet file, then `make_torch_dataloader` brings up a pool of workers to read that parquet file back. I assume the slowness you are referring to is coming from the `next` call? Can you confirm, please?
`make_torch_dataloader` takes a `petastorm_reader_kwargs` argument; you can see the full documentation in `make_batch_reader`. You can try tweaking some of the parameters there (`reader_pool_type`, `workers_count`) to play with different parallelization settings (thread vs. process pool, number of workers).
Hope this helps.
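For example, a minimal sketch of what that could look like, assuming an active SparkSession named `spark` and the `df_train` dataframe from the original post (the cache path and parameter values below are illustrative, not recommendations):

```python
from petastorm.spark import SparkDatasetConverter, make_spark_converter

# Directory where the converter materializes the dataframe as parquet
# (assumes an active SparkSession named `spark`)
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF,
               "file:///tmp/petastorm_cache")

converter_train = make_spark_converter(df_train)

# Reader tuning options are forwarded via petastorm_reader_kwargs
with converter_train.make_torch_dataloader(
        batch_size=1024,
        petastorm_reader_kwargs={
            "reader_pool_type": "thread",  # or "process"
            "workers_count": 10,
        }) as train_loader:
    for batch in train_loader:
        ...  # training step goes here
```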
Hi, thanks for your help! My main concern is that using `torch.stack` with every `next` call is inefficient, and I'm missing an obvious way to use the `SparkDatasetConverter` that doesn't require calling `stack`. That's why I tried creating an array column in my dataframe first, but that turned out to be slower.
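For reference, the pattern in question looks roughly like this (a sketch only, reusing the x1/x2/y column names from the original post and the `converter_train` name from the snippet above, not the exact code from the issue):

```python
import torch

# Each batch from make_torch_dataloader arrives as a dict of 1-D tensors
# keyed by column name, so the feature columns are re-stacked into a
# single (N, 2) tensor on every iteration -- the step suspected of
# being slow.
with converter_train.make_torch_dataloader(batch_size=1024) as train_loader:
    for batch in train_loader:
        features = torch.stack([batch["x1"], batch["x2"]], dim=1).float()
        target = batch["y"].float()
        ...  # forward/backward pass
```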
I will dive into the full documentation that you sent and play around with some parameters.
@jmpanfil how did your experimentation with parameters go?
I've been working on using petastorm to train PyTorch models from spark dataframes (somewhat following this guide). I'm curious if there are any ways I can speed up data loading.
Here's a basic overview of my current flow.
`df_train` is a Spark dataframe with three columns: x1 (float), x2 (binary 0/1), and y (float). I'm using pyspark.

My concern is that the `torch` operations might not be optimal. Something else I tried was first creating an array column in my Spark dataframe for x1 and x2. I was surprised to find that each epoch was more than 2 times slower than the above strategy.

Are there any improvements I can make here?
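As a rough sketch of that array-column variant (the names here are hypothetical, and it assumes `make_spark_converter` as in the snippets above; the casts just guard against mixing float and integer types inside the array):

```python
from pyspark.sql import functions as F

# Pack x1 and x2 into a single array column up front so each batch
# already arrives as an (N, 2) feature tensor and no torch.stack is
# needed per iteration.
df_train_array = (df_train
                  .withColumn("features",
                              F.array(F.col("x1").cast("float"),
                                      F.col("x2").cast("float")))
                  .select("features", "y"))

converter_array = make_spark_converter(df_train_array)

with converter_array.make_torch_dataloader(batch_size=1024) as train_loader:
    for batch in train_loader:
        features = batch["features"].float()
        target = batch["y"].float()
        ...  # training step
```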