Closed: Shiro-LK closed this issue 4 years ago
@davidel can confirm, but I think we currently don't support loading multiple tfrecords at once. For the loader part, in the multi-process case I think you should follow the parallel_loader example we have here.
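As a rough sketch (not taken from the examples themselves) of what that pattern looks like; my_dataset, the batch size, and the worker count are placeholders:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()
    # my_dataset is a placeholder for your map-style Dataset.
    loader = torch.utils.data.DataLoader(my_dataset, batch_size=16, num_workers=4)
    # ParallelLoader feeds batches to the TPU core owned by this process.
    para_loader = pl.ParallelLoader(loader, [device])
    for batch in para_loader.per_device_loader(device):
        ...  # training step goes here

# One process per TPU core:
# xmp.spawn(_mp_fn, nprocs=8)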
You might find this tfRecordReader example useful too.
The TF record reader is an enumeration interface: you keep calling read_example() until you get an exception.
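A minimal sketch of that loop, assuming the torch_xla.utils.tf_record_reader.TfRecordReader API (the file path and compression are placeholders; the constructor arguments and the exact end-of-stream behavior may differ between torch_xla versions):

from torch_xla.utils import tf_record_reader

# Path and compression are placeholders; adjust them to your files.
reader = tf_record_reader.TfRecordReader('/path/to/shard.tfrecord', compression='GZIP')
while True:
    try:
        example = reader.read_example()  # a dict of feature name -> tensor
    except Exception:
        break  # the reader signals end-of-file by raising
    if example is None:
        break  # some versions return None at end-of-file instead
    # ... process `example` ...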
Hi, I am having trouble with this. I am using TFRecordReader inside a torch IterableDataset, but once I pass the Dataset to the DataLoader it starts conflicting with the DistributedSampler. Is there a standard way to bypass the use of the Sampler in this case?
Can you post the implementation of your Dataset and the specific error you get? It is hard to help without more information 😉
Hi @davidel, thank you for your response. I will share with you a kaggle kernel where the issue is made evident.
I am also copying here part of the code in question.
# I use this function to get the TFRecord iterator
# (imports and AUTO shown for completeness; AUTO is assumed to be tf.data's AUTOTUNE)
import tensorflow as tf
import tensorflow_datasets as tfds

AUTO = tf.data.experimental.AUTOTUNE

def get_dataset(files, batch_size=16, repeat=False, cache=False, shuffle=False, labeled=True, return_image_ids=False):
    ds = tf.data.TFRecordDataset(files, num_parallel_reads=AUTO, compression_type="GZIP")
    if cache:
        # You'll need around 15GB RAM if you'd like to cache val dataset, and 50~60GB RAM for train dataset.
        ds = ds.cache()
    if repeat:
        ds = ds.repeat()
    if shuffle:
        ds = ds.shuffle(1024 * 2)
        opt = tf.data.Options()
        opt.experimental_deterministic = False
        ds = ds.with_options(opt)
    if labeled:
        ds = ds.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
    else:
        ds = ds.map(lambda example: read_unlabeled_tfrecord(example, return_image_ids), num_parallel_calls=AUTO)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(AUTO)
    return tfds.as_numpy(ds)
# Then I create an iterable dataset
class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, files, batch_size=16, repeat=False, cache=False, shuffle=False, labeled=True, return_image_ids=False):
        super().__init__()
        # Pass the constructor arguments through instead of hard-coding them.
        self.ds = get_dataset(files, batch_size=batch_size, repeat=repeat, cache=cache,
                              shuffle=shuffle, labeled=labeled, return_image_ids=return_image_ids)

    def __iter__(self):
        return iter(self.ds)
Up to there everything works well, but then I do not know what to do with the Sampler so that the dataset works in Parallel mode. I think I am missing something conceptually here. What I try is the following:
train_dataset = MyIterableDataset(files_train, batch_size=CFG.batch_size,
                                  shuffle=True, cache=False, repeat=True)
valid_dataset = MyIterableDataset(files_valid, batch_size=CFG.batch_size,
                                  shuffle=True, cache=False, repeat=True)

train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

train_loader = DataLoader(train_dataset,
                          batch_size=None,
                          sampler=train_sampler,
                          num_workers=CFG.num_workers, pin_memory=True, drop_last=True)

para_loader = pl.ParallelLoader(valid_loader, [device])
It seems the snippet you posted uses tf.data.TFRecordDataset, which is a different code base. Internally they both end up in the same TF C++ code, but the upper API layer is different.
Yes, but when using TfRecordReader and then creating the IterableDataset from it, I run into a similar issue. The issue is in properly instantiating the Sampler and/or Loader. In particular, the error is related to IterableDatasets not having __len__ defined. I could just define it, but it does not seem correct.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    fn(gindex, *args)
  File "/tmp/ipykernel_43/1374819804.py", line 3, in _mp_fn
    a = train_loop()
  File "/tmp/ipykernel_43/385688159.py", line 54, in train_loop
    shuffle=True)
  File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/distributed.py", line 91, in __init__
    self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore
TypeError: object of type 'MyIterableDataset' has no len()
I see the issue. The Sampler wants a len() function in order to partition the sample IDs, so an IterableDataset is probably not going to work in that case. One simple solution, if you have enough memory, is to slurp the whole dataset into a Python list and create a normal Dataset with __len__() and __getitem__() APIs.
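Something like this minimal sketch, where read_all_examples is a placeholder for whatever code decodes your TF-Record samples:

import torch

class InMemoryDataset(torch.utils.data.Dataset):
    def __init__(self, files):
        # read_all_examples is a placeholder: it should yield decoded samples.
        self.samples = list(read_all_examples(files))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

With __len__() and __getitem__() defined, the DistributedSampler setup from your snippet works unchanged.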
My dataset is very large, so I do not think that will cut it. I am looking at the WebDataset library, which works fine with XLA from the examples I have seen. The only problem there is that I cannot use the TFRecord format; they use an IterableDataset. I will try to see if I can leverage part of their code.
Essentially a TF-Record file is a streaming-like format, while things like the Sampler, or even a simple shuffle operation, want to do random access. IIRC a TF-Record file is a single compressed file, which means it is not even seek-able. You could rewrite the TF-Record file into a file which supports random access.

This is some code I wrote when users were reporting high memory usage on Colab due to large datasets. It rewrites the dataset into one which can be read sample by sample, without loading the whole thing into memory. It was meant to lower memory usage, but it can also convert a streaming TF-Record file into a random-access capable one. You simply write a converter which uses the PyTorch/XLA TF record API and calls write_sample() on each example. At the end you have a data file and an index file, and you can use the FileDataset class as your Dataset.
import io
import torch

def write_sample(s, data_file, index_file):
    # Serialize the sample and record its byte offset in the index file.
    bio = io.BytesIO()
    torch.save(s, bio)
    offset = data_file.tell()
    index_file.write((offset).to_bytes(8, byteorder='little'))
    data_file.write(bio.getvalue())


class FileDataset(object):
    def __init__(self, path):
        self._data_file = open(path + '.data', 'rb')
        self._index_file = open(path + '.index', 'rb')
        self._index_file.seek(0, 2)
        self._index_size = self._index_file.tell()
        assert self._index_size % 8 == 0
        self._data_file.seek(0, 2)
        self._data_size = self._data_file.tell()

    def __getitem__(self, idx):
        index_offset = idx * 8
        assert index_offset < self._index_size
        self._index_file.seek(index_offset)
        data_offset = int.from_bytes(self._index_file.read(8), byteorder='little')
        if index_offset + 8 < self._index_size:
            # The next index entry tells us where this sample ends.
            next_offset = int.from_bytes(self._index_file.read(8), byteorder='little')
        else:
            # Last sample: it extends to the end of the data file.
            next_offset = self._data_size
        self._data_file.seek(data_offset)
        sample_data = self._data_file.read(next_offset - data_offset)
        return torch.load(io.BytesIO(sample_data))

    def __len__(self):
        return self._index_size // 8
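As a rough illustration of the converter step (not tested; the path and compression are placeholders, and the end-of-stream handling may need adjusting to your torch_xla version):

from torch_xla.utils import tf_record_reader

def convert_tfrecord(tfrecord_path, out_path):
    # Uses write_sample() defined above to build the .data/.index pair.
    reader = tf_record_reader.TfRecordReader(tfrecord_path, compression='GZIP')
    with open(out_path + '.data', 'wb') as data_file, \
         open(out_path + '.index', 'wb') as index_file:
        while True:
            try:
                example = reader.read_example()
            except Exception:
                break  # end of stream
            if example is None:
                break
            write_sample(example, data_file, index_file)

# Afterwards FileDataset(out_path) has __len__()/__getitem__(), so the usual
# DataLoader + DistributedSampler setup works with it.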
❓ Questions and Help
Hi,
I am interested in the feature of using tfrecords. Looking at the documentation, I see it is possible to use them, but I did not find a way to load multiple tfrecords at once. I would like to use tfrecords on TPUs, so I have 8 processes running. Do I have to use the same number of dataloaders/samplers as I have tfrecords, or is there a way to load multiple tfrecords with one dataloader/sampler?
Thanks for the help!