🚀 Feature

Right now `collate_fn` creates the batch dimension as the first dimension. It should be possible to pass a parameter that chooses where the batch dimension is inserted.
Motivation
There are situations in which it would be useful for the batch dimension not to be the first. Recurrent Neural Networks come to mind, where having shape `T x B x ...` improves data locality compared to `B x T x ...`.
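For context, the recurrent modules in `torch.nn` already default to the time-major layout (`batch_first=False`), so such batches could be fed to them directly. A quick shape check (illustrative only, not part of the proposal):

```python
import torch
import torch.nn as nn

# nn.LSTM defaults to batch_first=False, i.e. it expects input of shape
# (seq_len, batch, input_size) -- exactly the T x B x ... layout.
lstm = nn.LSTM(input_size=8, hidden_size=16)

x = torch.randn(5, 3, 8)          # T=5, B=3, 8 features per timestep
output, (h_n, c_n) = lstm(x)
print(output.shape)               # torch.Size([5, 3, 16])
```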
It is of course possible to write a custom `collate_fn` function and pass it to the `DataLoader`. However, having to replicate the functionality of `collate_fn` (which has a lot of quality-of-life features) just to add a `dim = 1` argument in the call to `torch.stack` feels a bit unnecessary.
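For illustration, the workaround available today looks roughly like this minimal sketch (with a toy dataset; it only handles batches of plain tensors and drops all the other cases the default `collate_fn` supports):

```python
import torch
from torch.utils.data import DataLoader, Dataset

def time_major_collate(batch):
    # Stack along dim 1 instead of dim 0, yielding T x B x ... batches.
    # Unlike the default collate_fn, this only handles lists of plain tensors.
    return torch.stack(batch, dim=1)

class ToySequences(Dataset):
    def __init__(self, n=16, seq_len=5, feat=8):
        self.data = torch.randn(n, seq_len, feat)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        return self.data[i]          # each sample has shape T x features

loader = DataLoader(ToySequences(), batch_size=4, collate_fn=time_major_collate)
x = next(iter(loader))
print(x.shape)                       # torch.Size([5, 4, 8]), i.e. T x B x features
```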
Furthermore, this new feature would not significantly increase the complexity of the implementation, nor negatively affect performance.
Pitch
A new parameter should be added to the `DataLoader` constructor, possibly called `batch_dim`, with default value `0` (which would replicate current behaviour). This parameter would be forwarded by the `DataLoader` to every call of `collate_fn`.
The same new parameter would be added to the default `collate_fn` function, which would pass it along in every recursive call and eventually forward it as the `dim` argument to `torch.stack`.
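For concreteness, the only behavioural difference is where `torch.stack` inserts the new dimension:

```python
import torch

samples = [torch.randn(5, 8) for _ in range(3)]   # three samples of shape T x features

print(torch.stack(samples, dim=0).shape)  # torch.Size([3, 5, 8]) -> B x T x ... (current default)
print(torch.stack(samples, dim=1).shape)  # torch.Size([5, 3, 8]) -> T x B x ... (batch_dim = 1)
```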
According to this simple idea, `batch_dim` would be passed unconditionally to any `collate_fn` function, therefore causing backwards-compatibility issues with custom `collate_fn` functions that do not expect it. This issue could be overcome in two ways:
1. Make the `batch_dim` parameter of `DataLoader` mutually exclusive with the `collate_fn` parameter, therefore allowing the specification of a custom batch dimension only for the default `collate_fn`. However, this feels a bit ad hoc. Also, the feature might be useful for custom functions as well, especially if the `collate_fn` function gets integrated into the `Dataset` API. In that case, passing `batch_dim` to it would allow the same `Dataset` to be used with models that require different dimension orders.
2. Make the default value of the `batch_dim` parameter (of the `DataLoader`) be `None`, and make its forwarding to `collate_fn` conditional on its value not being `None`. In this way, backwards compatibility is ensured, while allowing users to opt in to this feature for their custom `collate_fn` functions.
I am aware of the current discussions on integrating `collate_fn` into the `Dataset` API. This feature is orthogonal to that and (as noted in the first alternative above) could integrate seamlessly with it.
Additional context
Here is how `collate_fn` could look (based on commit 58eb233, changes highlighted):
```python
def default_collate(batch, batch_dim = 0): #### <---
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, batch_dim, out=out) #### <---
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch], batch_dim) #### <---
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch], batch_dim) for key in elem} #### <---
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples, batch_dim) for samples in zip(*batch))) #### <---
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples, batch_dim) for samples in transposed] #### <---

    raise TypeError(default_collate_err_msg_format.format(elem_type))
```
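With this change, calling the modified `default_collate` directly would behave as follows (hypothetical, since `batch_dim` does not exist yet); nested containers are collated along the requested dimension as well:

```python
import torch

# Hypothetical usage of the modified default_collate defined above.
batch = [{'x': torch.randn(5, 8), 'y': i} for i in range(4)]

out = default_collate(batch, batch_dim=1)
print(out['x'].shape)  # torch.Size([5, 4, 8]) -- batch dimension inserted at dim 1
print(out['y'].shape)  # torch.Size([4])       -- Python scalars are collated as before
```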
Here is how the `DataLoader` constructor could look.
Current behaviour (based on commit 7f1693d)
```python
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
             shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
             batch_sampler: Optional[Sampler[Sequence[int]]] = None,
             num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
             pin_memory: bool = False, drop_last: bool = False,
             timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
             multiprocessing_context=None, generator=None,
             *, prefetch_factor: int = 2,
             persistent_workers: bool = False):
    ### OTHER STUFF HERE ###
    if collate_fn is None:
        if self._auto_collation:
            collate_fn = _utils.collate.default_collate
        else:
            collate_fn = _utils.collate.default_convert
    self.collate_fn = collate_fn
    ### OTHER STUFF HERE ###
```
New behaviour (based on the second alternative suggested above; default `batch_dim` value is `None`)
```python
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
             shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
             batch_sampler: Optional[Sampler[Sequence[int]]] = None,
             num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
             batch_dim: Optional[int] = None,
             pin_memory: bool = False, drop_last: bool = False,
             timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
             multiprocessing_context=None, generator=None,
             *, prefetch_factor: int = 2,
             persistent_workers: bool = False):
    ### OTHER STUFF HERE ###
    if collate_fn is None:
        if self._auto_collation:
            collate_fn = _utils.collate.default_collate
        else:
            collate_fn = _utils.collate.default_convert
    if batch_dim is not None:
        self.collate_fn = functools.partial(collate_fn, batch_dim = batch_dim)
    else:
        self.collate_fn = collate_fn
    ### OTHER STUFF HERE ###
```
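End-to-end, usage under the second alternative could then look like this (again hypothetical, assuming the `batch_dim` keyword is added as sketched above):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class SequenceDataset(Dataset):
    def __init__(self, n=32, seq_len=10, feat=8):
        self.data = torch.randn(n, seq_len, feat)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        return self.data[i]           # each sample has shape T x features

# batch_dim=1 would make the default collate_fn stack along dim 1,
# producing T x B x ... batches without any extra permute.
loader = DataLoader(SequenceDataset(), batch_size=4, batch_dim=1)
x = next(iter(loader))
print(x.shape)                        # torch.Size([10, 4, 8])
```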
A similar (but more limited in scope) request was advanced in #10386
cc @SsnL @VitalyFedyunin @ejguan