When the batch contains tensors of differing lengths (as is the case with DGL graphs), the first tensor encountered may not be useful for inferring the batch size, producing inconsistent behavior (a Python dictionary does not necessarily return its keys in sorted order).
I propose inferring the number of samples in the batch from the length of `data.sample_id` when that key is present, and falling back to the current behavior otherwise:
```python
if 'data.sample_id' in keys:
    batch_size = len(batch['data.sample_id'])
else:
    batch_size = None
if batch_size is None:
    # Fall back to the current behavior: use the first Tensor found
    for key in keys:
        if isinstance(batch[key], torch.Tensor):
            batch_size = len(batch[key])
            break
```
https://github.com/BiomedSciAI/fuse-med-ml/blob/fb11d7ecdf9cec9fda5a7a36aa5cff694cdb3168/fuse/utils/data/collate.py#L123-L127
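To illustrate the problem, here is a minimal sketch (the key names `data.graph_edges` and `data.label` are hypothetical, chosen only to mimic a batch with mixed-length tensors): a graph-level tensor can have a length unrelated to the number of samples, so whichever tensor happens to come first determines the inferred batch size, while `data.sample_id` always has one entry per sample.

```python
import torch

# Hypothetical batch: per-sample tensors have length 4 (the true batch
# size), but the stacked edge tensor has length 10 (total edge count).
batch = {
    "data.graph_edges": torch.zeros(10, 2),  # 10 edges, NOT 4 samples
    "data.sample_id": torch.arange(4),       # one id per sample
    "data.label": torch.zeros(4),
}

def infer_batch_size(batch: dict) -> int:
    # Proposed logic: prefer data.sample_id when it is present ...
    if "data.sample_id" in batch:
        return len(batch["data.sample_id"])
    # ... otherwise keep the current behavior: first Tensor found.
    for value in batch.values():
        if isinstance(value, torch.Tensor):
            return len(value)
    return None

print(infer_batch_size(batch))  # 4, regardless of key iteration order
```

With the current first-Tensor logic, the same batch would yield 10 or 4 depending on which key is iterated first; the `data.sample_id` check makes the result deterministic.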