pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

TFrecord with torch xla #2434

Closed Shiro-LK closed 4 years ago

Shiro-LK commented 4 years ago

❓ Questions and Help

Hi,

I am interested in using TFRecords. The documentation shows that this is possible, but I did not find a way to load multiple TFRecord files at once. I would like to use TFRecords on TPUs, so I have 8 processes running. Do I need as many dataloaders/samplers as I have TFRecord files, or is there a way to load multiple TFRecord files with one dataloader/sampler?

Thanks for the help !

JackCaoG commented 4 years ago

@davidel can confirm, but I think we currently don't support loading multiple tfrecords at once. For the loader part, I think in the multi-process case you should follow the parallel_loader example we have here.

JackCaoG commented 4 years ago

You might find this tfRecordReader example useful too.

davidel commented 4 years ago

The TF record reader is an enumeration interface. You keep calling read_example() until you get an exception:

https://github.com/pytorch/xla/blob/0c34871b91159e51603119c42d9fe58fb7d1da88/torch_xla/utils/tf_record_reader.py#L8
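As a rough sketch of how that enumeration interface could cover several files per process: each of the 8 processes takes a disjoint slice of the file list and drains one reader after another. The names shard_files, iter_examples, and open_reader are illustrative and not part of torch_xla; in real use open_reader would presumably construct the TF record reader for a given path. The sketch treats either a None return or one of the listed exception types as end of file, which is an assumption; the exact end-of-stream signal should be checked against the linked source.

```python
from typing import Any, Callable, Iterator, List, Sequence, Tuple, Type


def shard_files(files: Sequence[str], rank: int, world_size: int) -> List[str]:
    """Give each process every world_size-th file, starting at its rank."""
    return list(files[rank::world_size])


def iter_examples(
    paths: Sequence[str],
    open_reader: Callable[[str], Any],
    eof_errors: Tuple[Type[BaseException], ...] = (EOFError,),
) -> Iterator[Any]:
    """Yield examples from each file in turn until every reader is drained.

    open_reader is a hypothetical factory returning an object that exposes
    read_example(). End of file is assumed to be signalled by a None return
    or by one of eof_errors; adjust this to what the real reader does.
    """
    for path in paths:
        reader = open_reader(path)
        while True:
            try:
                example = reader.read_example()
            except eof_errors:
                break
            if example is None:
                break
            yield example
```

With 8 files and 8 processes, shard_files(files, rank, 8) hands one file to each process; with more files than processes, each process reads several files back to back through a single loader.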

stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

MrRobot2211 commented 2 years ago

Hi, I am having trouble with this. I am using TFRecordReader inside a torch IterableDataset, but once I pass the Dataset to the DataLoader it conflicts with the DistributedSampler. Is there a standard way to bypass the use of the Sampler for this case?

davidel commented 2 years ago

Can you post the implementation of your Dataset and the specific error you get? It is hard to help without more information 😉

MrRobot2211 commented 2 years ago

Hi @davidel, thank you for your response. I will share with you a Kaggle kernel where the issue is evident.

I am also copying here part of the code in question.

# I use this function to get the TFRecord iterator
def get_dataset(files, batch_size=16, repeat=False, cache=False, shuffle=False, labeled=True, return_image_ids=False):
    ds = tf.data.TFRecordDataset(files, num_parallel_reads=AUTO, compression_type="GZIP")
    if cache:
        # You'll need around 15GB RAM if you'd like to cache val dataset, and 50~60GB RAM for train dataset.
        ds = ds.cache()

    if repeat:
        ds = ds.repeat()

    if shuffle:
        ds = ds.shuffle(1024 * 2)
        opt = tf.data.Options()
        opt.experimental_deterministic = False
        ds = ds.with_options(opt)

    if labeled:
        ds = ds.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
    else:
        ds = ds.map(lambda example: read_unlabeled_tfrecord(example, return_image_ids), num_parallel_calls=AUTO)

    ds = ds.batch(batch_size)
    ds = ds.prefetch(AUTO)
    return tfds.as_numpy(ds)

# Then I create an iterable dataset
class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, files, batch_size=16, repeat=False, cache=False, shuffle=False, labeled=True, return_image_ids=False):
        super().__init__()
        # Forward the constructor arguments instead of hard-coded defaults.
        self.ds = get_dataset(files, batch_size=batch_size, repeat=repeat, cache=cache, shuffle=shuffle, labeled=labeled, return_image_ids=return_image_ids)

    def __iter__(self):
        return iter(self.ds)

Up to there everything works well, but then I do not know what to do with the Sampler so that the dataset works in parallel mode. I think I am missing something conceptually here.

What I tried is the following:

train_dataset = MyIterableDataset(files_train, batch_size=CFG.batch_size,
                                  shuffle=True,
                                  cache=False, repeat=True)

valid_dataset = MyIterableDataset(files_valid, batch_size=CFG.batch_size,
                                  shuffle=True,
                                  cache=False, repeat=True)

train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

train_loader = DataLoader(train_dataset,
                          batch_size=None,
                          sampler=train_sampler,
                          num_workers=CFG.num_workers, pin_memory=True, drop_last=True)

para_loader = pl.ParallelLoader(valid_loader, [device])

davidel commented 2 years ago

It seems the snippet you posted uses tf.data.TFRecordDataset which is a different code base. Internally they both end up into the same TF C++ code, but the upper API layer is different.

MrRobot2211 commented 2 years ago

> It seems the snippet you posted uses tf.data.TFRecordDataset which is a different code base. Internally they both end up into the same TF C++ code, but the upper API layer is different.

Yes, but when using TfRecordReader and then creating the IterableDataset from it, I run into a similar issue. The issue is in properly instantiating the Sampler and/or the Loader.

In particular the error is related to IterableDatasets not having __len__ defined. I could just define it, but it does not seem correct.


Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    fn(gindex, *args)
  File "/tmp/ipykernel_43/1374819804.py", line 3, in _mp_fn
    a = train_loop()
  File "/tmp/ipykernel_43/385688159.py", line 54, in train_loop
    shuffle=True)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/distributed.py", line 91, in __init__
    self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore
TypeError: object of type 'MyIterableDataset' has no len()

davidel commented 2 years ago

I see the issue. The Sampler needs the dataset to define __len__() in order to partition the sample IDs. In that case an IterableDataset is probably not going to work. One simple solution, if you have enough memory, is to slurp the whole dataset into a Python list and create a normal Dataset with the __len__() and __getitem__() APIs.
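If the dataset does not fit in memory, one alternative is to skip the sampler and do the partitioning inside the iterator itself, which needs no length. A minimal sketch, assuming every process iterates the same record stream in the same order (so no non-deterministic shuffling upstream): each process keeps every world_size-th record starting at its own rank. shard_stream is an illustrative helper, not a PyTorch or torch_xla function. Inside an IterableDataset it would wrap the iterator returned from __iter__(), with rank and world size taken from xm.get_ordinal() and xm.xrt_world_size() as in the snippet above, and the DataLoader would be created without a sampler argument.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def shard_stream(stream: Iterable[T], rank: int, world_size: int) -> Iterator[T]:
    """Keep records rank, rank + world_size, rank + 2 * world_size, ...

    No len() is needed: the stream is consumed lazily, and records that
    belong to other processes are read and discarded.
    """
    if not 0 <= rank < world_size:
        raise ValueError("rank must satisfy 0 <= rank < world_size")
    return islice(stream, rank, None, world_size)
```

The cost is that every process still reads the full stream, and shards can differ in length by one record, which may matter if every process must run the same number of steps. Sharding at the file level avoids the redundant reads when the data is split across several TFRecord files.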

MrRobot2211 commented 2 years ago

My dataset is very heavy, so I do not think that will cut it. I am looking at the WebDataset library, which works fine with XLA from the examples I have seen. The only problem is that I cannot use the TFRecord format there. They use an IterableDataset, so I will try to see if I can leverage part of their code.

davidel commented 2 years ago

Essentially, a TF-Record file is a streaming-like format, while things like the Sampler, or even a simple shuffle operation, want random access. IIRC a TF-Record file is a single compressed file, which means it is not even seekable. You could rewrite the TF-Record file into a format that supports random access.

This is some code I wrote when we had users reporting high memory usage on Colab due to large datasets. It rewrites the dataset into one whose samples can be read one by one, without loading the whole thing in memory. It was used to lower memory usage, but it can also convert a streaming TF-Record file into a random-access capable one. You simply write a converter which reads samples with the PyTorch/XLA TF record API and passes each one to write_sample(). At the end you have a data file and an index file, and you can use the FileDataset class as the Dataset.

import io

import torch


def write_sample(s, data_file, index_file):
    # Serialize the sample, then record its start offset in the index file
    # as an 8-byte little-endian integer.
    bio = io.BytesIO()
    torch.save(s, bio)
    offset = data_file.tell()
    index_file.write(offset.to_bytes(8, byteorder='little'))
    data_file.write(bio.getvalue())


class FileDataset(object):
    def __init__(self, path):
        self._data_file = open(path + '.data', 'rb')
        self._index_file = open(path + '.index', 'rb')
        # Seek to the end of each file to learn its size.
        self._index_file.seek(0, 2)
        self._index_size = self._index_file.tell()
        assert self._index_size % 8 == 0
        self._data_file.seek(0, 2)
        self._data_size = self._data_file.tell()

    def __getitem__(self, idx):
        index_offset = idx * 8
        assert index_offset < self._index_size
        self._index_file.seek(index_offset)
        data_offset = int.from_bytes(self._index_file.read(8),
                                     byteorder='little')
        # A following index entry exists only if 8 more bytes remain after
        # the entry just read; otherwise this is the last sample.
        if index_offset + 16 <= self._index_size:
            next_offset = int.from_bytes(self._index_file.read(8),
                                         byteorder='little')
        else:
            next_offset = self._data_size
        self._data_file.seek(data_offset)
        sample_data = self._data_file.read(next_offset - data_offset)
        return torch.load(io.BytesIO(sample_data))

    def __len__(self):
        return self._index_size // 8
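The converter step itself is not shown in the thread. A minimal sketch of what it might look like, using the same layout (a data file plus an index of 8-byte little-endian offsets): to stay runnable without PyTorch it swaps torch.save/torch.load for the standard-library pickle module, and convert and read_sample are illustrative names. In a real conversion the samples iterable would come from the TF record reader.

```python
import os
import pickle
import tempfile
from typing import Any, Iterable


def convert(samples: Iterable[Any], path: str) -> int:
    """Stream samples into path.data and path.index; return the sample count."""
    count = 0
    with open(path + ".data", "wb") as data_file, \
            open(path + ".index", "wb") as index_file:
        for sample in samples:
            # Record where this sample starts, then append its bytes.
            index_file.write(data_file.tell().to_bytes(8, byteorder="little"))
            data_file.write(pickle.dumps(sample))
            count += 1
    return count


def read_sample(path: str, idx: int) -> Any:
    """Random-access read of sample idx using the offset index."""
    if idx < 0:
        raise IndexError(idx)
    with open(path + ".index", "rb") as index_file:
        index_file.seek(idx * 8)
        entry = index_file.read(8)
        if len(entry) != 8:
            raise IndexError(idx)
        start = int.from_bytes(entry, byteorder="little")
    with open(path + ".data", "rb") as data_file:
        data_file.seek(start)
        # pickle.load stops at the end of one pickled object,
        # so the next offset is not needed here.
        return pickle.load(data_file)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        base = os.path.join(tmp, "train")
        n = convert(({"id": i, "pixels": [i] * 3} for i in range(5)), base)
        print(n, read_sample(base, 3))
```

Only one sample is held in memory at a time during conversion, so the source can be an arbitrarily long stream.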

tottenjordan commented 2 years ago

@MrRobot2211 - This blog demonstrates a distributed training example using GCS and WebDataset. Here is a link to the code referenced in the blog.