pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[Feature Request] Optionally specify batch dimension in DataLoader and collate_fn #53982

Open lodo1995 opened 3 years ago

lodo1995 commented 3 years ago

🚀 Feature

Right now collate_fn creates the batch dimension as the first dimension. It should be possible to pass a parameter to choose to insert the batch dimension in a different position.

Motivation

There are situations in which it would be useful to have the batch dimension not be the first. Recurrent Neural Networks come to mind, where having shape T x B x ... improves data locality compared to B x T x ....
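
To make the locality argument concrete, here is a small pure-Python sketch (no PyTorch required) that computes row-major strides for both layouts and checks whether all elements of one time step sit in one contiguous block of memory. The helper names `row_major_strides`, `timestep_offsets` and `is_contiguous_run`, and the sizes `T`, `B`, `F`, are made up for this illustration.

```python
def row_major_strides(shape):
    """Element strides of a dense, row-major (C-contiguous) array."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def timestep_offsets(strides, t_axis, b_axis, t, batch_size, feat_size):
    """Flat memory offsets of every element belonging to time step `t`."""
    feat_stride = strides[2]  # the feature axis is last in both layouts
    return [
        t * strides[t_axis] + b * strides[b_axis] + f * feat_stride
        for b in range(batch_size)
        for f in range(feat_size)
    ]


def is_contiguous_run(offsets):
    """True if the offsets form one unbroken ascending range."""
    return offsets == list(range(offsets[0], offsets[0] + len(offsets)))


T, B, F = 5, 3, 4  # illustrative sizes: time steps, batch size, features

time_major = row_major_strides((T, B, F))   # layout T x B x F
batch_major = row_major_strides((B, T, F))  # layout B x T x F

tm_offsets = timestep_offsets(time_major, t_axis=0, b_axis=1, t=2,
                              batch_size=B, feat_size=F)
bm_offsets = timestep_offsets(batch_major, t_axis=1, b_axis=0, t=2,
                              batch_size=B, feat_size=F)

print(is_contiguous_run(tm_offsets))  # one block of B * F elements
print(is_contiguous_run(bm_offsets))  # B separate chunks of F elements
```

With T x B x ..., an RNN that consumes one time step at a time reads a single contiguous block per step. With B x T x ..., the same read is split into B chunks spaced T * F elements apart.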

It is of course possible to write a custom collate_fn function and pass that to the DataLoader. However, having to replicate the functionality of collate_fn (which has a lot of quality-of-life features) just to add a dim = 1 parameter in the call to torch.stack feels a bit unnecessary.
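
For reference, a workaround that avoids copying the body of the default collate function is to call it unchanged and move the batch dimension afterwards. The sketch below assumes `default_collate` can be imported from `torch.utils.data.dataloader` and that `torch.movedim` is available. `move_batch_dim` and `collate_time_major` are hypothetical names, not part of PyTorch. Unlike the proposal, this costs an extra copy per tensor, because the result is made contiguous after the move.

```python
import torch
from torch.utils.data.dataloader import default_collate


def move_batch_dim(collated, batch_dim):
    """Recursively move dim 0 of every tensor to position `batch_dim`.

    Hypothetical helper; assumes `batch_dim` is a non-negative int.
    """
    if isinstance(collated, torch.Tensor):
        if collated.dim() > batch_dim:
            # movedim returns a view; contiguous() materialises the new layout.
            return torch.movedim(collated, 0, batch_dim).contiguous()
        return collated  # e.g. a 1-D tensor built from Python scalars
    if isinstance(collated, dict):
        return {k: move_batch_dim(v, batch_dim) for k, v in collated.items()}
    if isinstance(collated, tuple) and hasattr(collated, "_fields"):  # namedtuple
        return type(collated)(*(move_batch_dim(v, batch_dim) for v in collated))
    if isinstance(collated, (list, tuple)):
        return type(collated)(move_batch_dim(v, batch_dim) for v in collated)
    return collated  # strings and other non-tensor leaves are left as they are


def collate_time_major(batch):
    """Collate with the stock function, then put the batch at dim 1."""
    return move_batch_dim(default_collate(batch), batch_dim=1)


# Four samples, each with a (T=7, F=16) sequence and an integer label.
samples = [{"x": torch.zeros(7, 16), "label": i} for i in range(4)]
out = collate_time_major(samples)
print(out["x"].shape, out["label"].shape)
```

It would be passed as `DataLoader(dataset, batch_size=4, collate_fn=collate_time_major)`. Tensors with too few dimensions to hold the requested batch position (such as the collated labels) keep the batch at dim 0.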

Furthermore, this new feature would not significantly increase the complexity of the implementation, nor negatively affect performance.

Pitch

A new parameter should be added to the DataLoader constructor, possibly called batch_dim, with default value 0 (which would replicate current behaviour). This parameter would be forwarded by the DataLoader to every call of collate_fn.

The same new parameter would be added to the default collate_fn function, which would pass it to itself in every recursive call and eventually pass it as the dim parameter to torch.stack.

According to this simple idea, batch_dim would be passed unconditionally to any collate_fn function, therefore causing backwards compatibility issues with those custom collate_fn that do not expect it. This issue could be overcome in two ways:

I am aware of current discussions on integrating collate_fn into the Dataset API. This feature is orthogonal to that and (as shown in the above paragraph) could integrate seamlessly with that.

Additional context

Here is what the default `collate_fn` could look like (based on commit 58eb233, changes highlighted):

```python
def default_collate(batch, batch_dim=0):  #### <---
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, batch_dim, out=out)  #### <---
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch], batch_dim)  #### <---
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch], batch_dim) for key in elem}  #### <---
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples, batch_dim) for samples in zip(*batch)))  #### <---
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples, batch_dim) for samples in transposed]  #### <---

    raise TypeError(default_collate_err_msg_format.format(elem_type))
```
Here is what the `DataLoader` constructor could look like.

Current behaviour (based on commit 7f1693d):

```python
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
             shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
             batch_sampler: Optional[Sampler[Sequence[int]]] = None,
             num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
             pin_memory: bool = False, drop_last: bool = False,
             timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
             multiprocessing_context=None, generator=None,
             *, prefetch_factor: int = 2,
             persistent_workers: bool = False):

    ### OTHER STUFF HERE ###

    if collate_fn is None:
        if self._auto_collation:
            collate_fn = _utils.collate.default_collate
        else:
            collate_fn = _utils.collate.default_convert

    self.collate_fn = collate_fn

    ### OTHER STUFF HERE ###
```

New behaviour (based on the second alternative suggested; the default `batch_dim` value is `None`):

```python
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
             shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
             batch_sampler: Optional[Sampler[Sequence[int]]] = None,
             num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
             batch_dim: Optional[int] = None,
             pin_memory: bool = False, drop_last: bool = False,
             timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
             multiprocessing_context=None, generator=None,
             *, prefetch_factor: int = 2,
             persistent_workers: bool = False):

    ### OTHER STUFF HERE ###

    if collate_fn is None:
        if self._auto_collation:
            collate_fn = _utils.collate.default_collate
        else:
            collate_fn = _utils.collate.default_convert

    if batch_dim is not None:
        # Only forward batch_dim when the user asked for it, so that
        # existing custom collate functions keep working unchanged.
        self.collate_fn = functools.partial(collate_fn, batch_dim=batch_dim)
    else:
        self.collate_fn = collate_fn

    ### OTHER STUFF HERE ###
```
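
The backwards-compatibility behaviour of the `None` default can be checked in isolation, without PyTorch. The sketch below pulls the wrapping logic from the new constructor out into a standalone function. `resolve_collate_fn`, `legacy_collate` and `dim_aware_collate` are hypothetical names used only for this illustration.

```python
import functools


def resolve_collate_fn(collate_fn, batch_dim=None):
    """Mirror of the proposed constructor logic (hypothetical helper).

    `batch_dim` is only bound when the caller explicitly supplied it.
    """
    if batch_dim is None:
        return collate_fn
    return functools.partial(collate_fn, batch_dim=batch_dim)


def legacy_collate(batch):
    """An existing custom collate function that knows nothing of batch_dim."""
    return list(batch)


def dim_aware_collate(batch, batch_dim=0):
    """A collate function written against the proposed interface."""
    return {"batch_dim": batch_dim, "items": list(batch)}


samples = [1, 2, 3]

# Existing code that never passes batch_dim keeps working.
legacy_result = resolve_collate_fn(legacy_collate)(samples)

# New code opts in and receives the parameter as a keyword argument.
aware_result = resolve_collate_fn(dim_aware_collate, batch_dim=1)(samples)

# Asking for batch_dim with a function that does not accept it fails
# with a TypeError when the batch is collated.
misuse_error = None
try:
    resolve_collate_fn(legacy_collate, batch_dim=1)(samples)
except TypeError as exc:
    misuse_error = exc

print(legacy_result, aware_result, type(misuse_error).__name__)
```

A caller who combines `batch_dim` with an incompatible custom `collate_fn` therefore gets an explicit error at the first batch, not a silently ignored setting.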

A similar (but more limited in scope) request was made in #10386.

cc @SsnL @VitalyFedyunin @ejguan

VitalyFedyunin commented 3 years ago

Would be possible after migrating to DataPipes #49440