uber / petastorm

The Petastorm library enables single-machine or distributed training and evaluation of deep learning models from datasets in Apache Parquet format. It supports ML frameworks such as TensorFlow, PyTorch, and PySpark and can be used from pure Python code.
Apache License 2.0

Spark Dataset Converter reset reader position does not work as expected #553

Open liangz1 opened 4 years ago

liangz1 commented 4 years ago

When using conv.make_torch_dataloader(num_epochs=1) as the dataloader, the dataloader should support multiple calls of enumerate(loader). Using the following code snippet as an example, we define the expected behavior:

from pyspark.sql import SparkSession
from petastorm.spark import SparkDatasetConverter, make_spark_converter

spark = SparkSession.builder.getOrCreate()
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, 'file:///tmp/petastorm')

df = spark.range(10)  # 10 rows, so each epoch should fit in one batch of size 32
conv = make_spark_converter(df)

with conv.make_torch_dataloader(num_epochs=1, batch_size=32) as loader:
    for epoch in range(10):
        print(f"epoch: {epoch}")
        for step, batch in enumerate(loader):
            print(f"step: {step} batch length: {batch['id'].size(0)}")
        print()

Expected behavior:

epoch: 0
step: 0 batch length: 10

epoch: 1
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 10

epoch: 2
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 10

epoch: 3
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 10

epoch: 4
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 10

epoch: 5
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 10

epoch: 6
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 10

epoch: 7
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 10

epoch: 8
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 10

epoch: 9
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 10

Actual behavior:

epoch: 0
step: 0 batch length: 10

epoch: 1
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 20

epoch: 2
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 30

epoch: 3
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 32
step: 1 batch length: 8

epoch: 4
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 18

epoch: 5
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 28

epoch: 6
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 32
step: 1 batch length: 6

epoch: 7
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 16

epoch: 8
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 26

epoch: 9
Start a new pass of Petastorm DataLoader, reset underlying Petastorm reader to position 0.
step: 0 batch length: 32
step: 1 batch length: 4

I ran this on petastorm==0.9.0, with both pyspark==2.4.5 (on my laptop) and pyspark==3.0.0. The outputs are identical in both runs.
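For what it's worth, the growing batch sizes (10, 20, 30, then 32 + 8, then 18, ...) are exactly what you would see if the loader's internal row buffer survived the reader reset: each new pass appends the 10 re-read rows onto whatever the previous pass left behind, and the final partial batch is yielded without being cleared. A minimal self-contained sketch of that failure mode (hypothetical, not petastorm's actual code) reproduces the reported sequence:

```python
# Hypothetical illustration (NOT petastorm's actual code): a loader whose
# internal row buffer is not cleared when the reader resets to position 0.
class BuggyLoader:
    def __init__(self, n_rows=10, batch_size=32):
        self.n_rows = n_rows
        self.batch_size = batch_size
        self._buffer = []  # survives across passes: the suspected bug

    def __iter__(self):
        # Each pass re-reads all rows from position 0 ...
        for row in range(self.n_rows):
            self._buffer.append(row)
            if len(self._buffer) >= self.batch_size:
                batch = self._buffer[:self.batch_size]
                self._buffer = self._buffer[self.batch_size:]
                yield batch
        # ... then flushes the final partial batch WITHOUT clearing it.
        if self._buffer:
            yield list(self._buffer)

loader = BuggyLoader()
sizes = [[len(batch) for batch in loader] for _ in range(4)]
print(sizes)  # [[10], [20], [30], [32, 8]] -- matches the report above
```

Iterating further yields batch lengths 18, 28, 32 + 6, and so on, matching the actual output for epochs 4 through 6 as well.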

Warra07 commented 3 years ago

Same issue

Warra07 commented 3 years ago

As a quick fix: I can't quite find where the error is coming from. I tried resetting things manually, piece by piece, by switching some variables and resetting the ventilator/reader, and the error doesn't seem to come from the reader itself, since I even tried completely replacing the dataloader's reader variable with a fresh one from make_batch_reader.

So for now, I save the torch dataset context manager and generate a new dataloader every time I want to reset things:

manager = converter.make_torch_dataloader(...)

for x in range(3):
    print("epoch", x)
    train_data_loader = manager.__enter__()  # re-enter to get a fresh loader
    it = iter(train_data_loader)
    for batch_number in range(n_batches):
        pd_batch = next(it)

manager.__exit__(None, None, None)  # close the last loader

Alternatively, this seems to work better, as it exits the dataloader cleanly each epoch and avoids errors:

manager = converter.make_torch_dataloader(...)

for x in range(3):
    with manager as train_data_loader:
        print("epoch", x)
        it = iter(train_data_loader)
        for batch_number in range(n_batches):
            pd_batch = next(it)

Note that in both snippets, you'll need to pass num_epochs=1 to make_torch_dataloader.
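A slightly tidier variant of the second workaround is to wrap the per-epoch re-entry in a generator, so the training loop doesn't have to manage the context manager itself. This is only a sketch built on the workaround above; epoch_batches is a hypothetical helper, not part of the petastorm API, and it creates a new dataloader per epoch rather than re-entering the same manager:

```python
def epoch_batches(converter, n_epochs, **dataloader_kwargs):
    """Yield (epoch, batch) pairs, opening a fresh single-pass loader per epoch.

    `converter` is assumed to be the object returned by make_spark_converter.
    A new dataloader is created and closed for every epoch, which sidesteps
    the broken reset at the cost of re-opening the reader each time.
    """
    for epoch in range(n_epochs):
        with converter.make_torch_dataloader(num_epochs=1,
                                             **dataloader_kwargs) as loader:
            for batch in loader:
                yield epoch, batch

# usage:
# for epoch, batch in epoch_batches(conv, n_epochs=3, batch_size=32):
#     ...
```

Since the converter caches the materialized Parquet files, the per-epoch cost should mostly be reader setup, not a Spark recomputation.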