BiomedSciAI / fuse-med-ml

A Python framework accelerating ML-based discovery in the medical field by encouraging code reuse. Batteries included :)
Apache License 2.0
134 stars, 34 forks

Error in complex batches #210

Closed: afoncubierta closed this issue 1 year ago

afoncubierta commented 1 year ago

https://github.com/BiomedSciAI/fuse-med-ml/blob/fb11d7ecdf9cec9fda5a7a36aa5cff694cdb3168/fuse/utils/data/collate.py#L123-L127

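When a batch contains several tensors (as is the case with DGL graphs), the first tensor found may not be suitable for inferring the batch size, for example because its first dimension is not the number of samples. This produces inconsistent behavior: Python dictionaries do not return their keys in sorted order, so which tensor is found first depends on how the batch was assembled.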

I propose inferring the number of samples in the batch from the length of data.sample_id when it is present, and falling back to the current behavior otherwise:

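    import torch

    # `batch` and `keys` are provided by the surrounding collate code.
    # Prefer the sample ids when present: they hold exactly one entry per sample.
    if 'data.sample_id' in keys:
        batch_size = len(batch['data.sample_id'])
    else:
        # Fall back to the current behavior: use the length of the first tensor.
        batch_size = None
        for key in keys:
            if isinstance(batch[key], torch.Tensor):
                batch_size = len(batch[key])
                break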
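A minimal, self-contained sketch of the problem and the proposed fix, runnable without torch or DGL. It assumes a hypothetical FakeTensor class standing in for torch.Tensor (only len() matters here), hypothetical helper names first_tensor_batch_size and proposed_batch_size, and hypothetical keys data.node_feats and data.label; it is not the fuse-med-ml collate code. The two batches hold the same values inserted in different orders, so the current rule's answer changes with key order while the proposed rule's answer does not.

```python
from typing import Optional


class FakeTensor(list):
    """Hypothetical stand-in for torch.Tensor: only len() is used here."""


def first_tensor_batch_size(batch: dict) -> Optional[int]:
    """Current rule: length of the first tensor-like value, in key order."""
    for key in batch:
        if isinstance(batch[key], FakeTensor):
            return len(batch[key])
    return None


def proposed_batch_size(batch: dict) -> Optional[int]:
    """Proposed rule: prefer 'data.sample_id', else fall back to the current rule."""
    if "data.sample_id" in batch:
        return len(batch["data.sample_id"])
    return first_tensor_batch_size(batch)


# Two samples whose (hypothetical) graph node features have one row per node,
# 5 nodes in total, so their first dimension is not the number of samples.
sample_ids = ["s0", "s1"]
node_feats = FakeTensor([[0.0]] * 5)
labels = FakeTensor([0, 1])

# Same content, different insertion order.
batch_a = {"data.sample_id": sample_ids, "data.node_feats": node_feats, "data.label": labels}
batch_b = {"data.sample_id": sample_ids, "data.label": labels, "data.node_feats": node_feats}

print(first_tensor_batch_size(batch_a), first_tensor_batch_size(batch_b))  # 5 2
print(proposed_batch_size(batch_a), proposed_batch_size(batch_b))  # 2 2
```

In real code, FakeTensor would be replaced by torch.Tensor. Note that the fallback still depends on key order when data.sample_id is missing, so batches without sample ids keep exactly the current behavior.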