Open xichens opened 1 year ago
A follow-up issue related to this one. Even without calling `wrap_data_loader` from `BatchMemoryManager`, the `DPDataLoader` has trouble handling empty batches in certain cases. Currently, the `wrap_collate_with_empty` function here creates empty tensors for empty batches. However, it sets only the shape and not the dtype of those tensors, so by default they are float tensors. Some modules expect a particular input data type; for example, `torch.nn.Embedding` requires int or long inputs.
I think `wrap_collate_with_empty` should consider the dtypes as well as the shapes of the actual samples when creating the empty tensors.
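To illustrate the suggestion, here is a minimal sketch of a collate wrapper that records both shapes and dtypes from a real sample and reuses them for empty batches. The names `collate_or_empty` and `stack_collate` and the sample layout are hypothetical and chosen for illustration; this is not the Opacus implementation.

```python
import torch


def stack_collate(batch):
    """Hypothetical collate function: stack each field of the samples."""
    return [torch.stack(field) for field in zip(*batch)]


def collate_or_empty(batch, collate_fn, empty_shapes, dtypes):
    """Collate `batch`, or build zero-length tensors when it is empty.

    `empty_shapes` and `dtypes` are assumed to be taken from one real sample,
    so the empty tensors match what downstream modules expect.
    """
    if len(batch) > 0:
        return collate_fn(batch)
    return [
        torch.zeros(shape, dtype=dtype)
        for shape, dtype in zip(empty_shapes, dtypes)
    ]


# Hypothetical sample: (token ids for an embedding layer, float target)
sample = (torch.tensor([3, 1, 4]), torch.tensor(1.0))
empty_shapes = [(0, *x.shape) for x in sample]
dtypes = [x.dtype for x in sample]

empty = collate_or_empty([], stack_collate, empty_shapes, dtypes)

# The empty token tensor keeps its integer dtype, so the embedding accepts it.
embedding = torch.nn.Embedding(10, 4)
out = embedding(empty[0])
```

With only the shape recorded, `empty[0]` would be a float tensor and the `embedding` call would be rejected because of the index dtype.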
🐛 Bug
When Poisson sampling is used, empty batches can occur. However, the `BatchSplittingSampler` from `batch_memory_manager.py`, which is called when using the `BatchMemoryManager`, cannot handle empty batches and will throw an error.
To Reproduce
To reproduce it, see this colab link.
Expected behavior
The wrapped batch sampler should handle empty batches properly.
Additional context
I think the issue is with this line. When it is called, `batch_idxs` can be an empty list, since it comes from a `UniformWithReplacementSampler`, but `np.array_split` does not expect the first arg to be empty.
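The failure can be sketched in isolation. This assumes the number of sections is derived from the batch length as `ceil(len(batch_idxs) / max_batch_size)`; the referenced line is not shown here, so that formula is an assumption. Under it, an empty batch asks `np.array_split` for zero sections, which raises a `ValueError`, whereas an empty first argument with at least one section is accepted. The `split_batch` helper is hypothetical and shows only one possible guard.

```python
import math

import numpy as np


def split_batch(batch_idxs, max_batch_size):
    """Split indices into chunks of at most `max_batch_size`.

    Hypothetical guard: an empty batch is returned as one empty chunk
    instead of asking np.array_split for zero sections.
    """
    if len(batch_idxs) == 0:
        return [np.array([], dtype=np.int64)]
    n_chunks = math.ceil(len(batch_idxs) / max_batch_size)
    return np.array_split(batch_idxs, n_chunks)


# Unguarded call: an empty batch yields zero sections, which numpy rejects.
try:
    np.array_split([], math.ceil(0 / 4))
    raised = False
except ValueError:
    raised = True

empty_chunks = split_batch([], 4)
full_chunks = split_batch(list(range(10)), 4)
```

Whether an empty batch should be passed on as a single empty chunk or skipped is a design choice for the sampler. Passing it on keeps the number of logical steps unchanged, which seems to matter for the privacy accounting under Poisson sampling.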