pytorch / opacus

Training PyTorch models with differential privacy
https://opacus.ai
Apache License 2.0

Support empty batches for arbitrary dataset structures #534

Open ffuuugor opened 1 year ago


For context, see the discussion in #530 (and thanks to @joserapa98 for pointing out the issue).

At the moment (to be precise, once #530 is merged), Opacus supports empty batches only for datasets with a simple structure: every record must be a tuple whose elements are simple types, either tensors or primitives.

For instance, datasets whose records look like (Tensor, int) or (Tensor, Tensor) are supported, but datasets with records like (Tensor, (int, int)) are not.
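To make the limitation concrete, here is a minimal sketch of a "flat" empty-batch builder that inspects only the top-level fields of a single record. The function name flat_empty_batch and the scalar dtype table (bool, int and float mapped to torch.bool, torch.int64 and torch.float64, mirroring PyTorch's default collate) are illustrative assumptions, not Opacus code. A (Tensor, int) record works, while the nested (int, int) field of a (Tensor, (int, int)) record has no single shape or dtype and is rejected.

```python
import torch

# Assumed scalar dtypes, mirroring PyTorch's default collate conventions.
_PRIMITIVE_DTYPES = {bool: torch.bool, int: torch.int64, float: torch.float64}


def flat_empty_batch(sample):
    """Hypothetical flat builder: looks at the top-level fields of one record only."""
    fields = []
    for x in sample:
        if isinstance(x, torch.Tensor):
            # Tensor field: keep the per-sample shape, prepend a batch size of 0.
            fields.append(torch.zeros((0, *x.shape), dtype=x.dtype))
        elif type(x) in _PRIMITIVE_DTYPES:
            # Primitive field: a 1-D empty tensor with the matching dtype.
            fields.append(torch.zeros((0,), dtype=_PRIMITIVE_DTYPES[type(x)]))
        else:
            # Nested containers such as (int, int) end up here.
            raise TypeError(f"unsupported field type: {type(x).__name__}")
    return fields


print([f.shape for f in flat_empty_batch((torch.randn(3), 7))])
try:
    flat_empty_batch((torch.randn(3), (1, 2)))
except TypeError as err:
    print(err)
```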

PyTorch addresses a similar problem in the collate function behind torch.utils.data.default_collate with the following code:

```python
# Excerpt from collate(batch, *, collate_fn_map=None) in torch/utils/data/_utils/collate.py.
# Inside that function, elem = batch[0] and elem_type = type(elem);
# each branch recurses into nested containers by calling collate() again.
import collections.abc

if isinstance(elem, collections.abc.Mapping):
    try:
        return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
    except TypeError:
        # The mapping type may not support `__init__(iterable)`.
        return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
    return elem_type(*(collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
    # check to make sure that the elements in batch have consistent size
    it = iter(batch)
    elem_size = len(next(it))
    if not all(len(elem) == elem_size for elem in it):
        raise RuntimeError('each element in list of batch should be of equal size')
    transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

    if isinstance(elem, tuple):
        return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
    else:
        try:
            return elem_type([collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
```
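For reference, a short sketch of how torch.utils.data.default_collate (which wraps this collate logic) handles a nested (Tensor, (int, int)) record, and why an empty batch needs special treatment: in recent PyTorch releases collate starts by reading batch[0], so an empty list raises IndexError before any structure can be inferred. The records below are made up for illustration.

```python
import torch
from torch.utils.data import default_collate

# Two illustrative records with a nested structure: (Tensor, (int, int)).
batch = [(torch.zeros(2), (1, 2)), (torch.ones(2), (3, 4))]
features, (first, second) = default_collate(batch)
print(features.shape)  # tensors are stacked along a new batch dimension
print(first, second)   # each inner int position becomes its own 1-D tensor

# An empty batch has no first element, so the structure cannot be inferred.
try:
    default_collate([])
except IndexError as err:
    print("empty batch:", err)
```

Note that the inner (int, int) field comes back as a list of two 1-D tensors rather than a single tensor; an empty batch would ideally reproduce exactly that nesting.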

We need to adapt this logic so that DPDataLoader can produce empty batches for datasets of arbitrary structure.
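One possible direction, sketched under assumptions: walk a representative record (for example dataset[0]) with the same branch order as the collate code above, and build a zero-length batch with matching structure. empty_like_collated and wrap_collate_with_empty_any are hypothetical names, the scalar dtypes follow PyTorch's default collate conventions, and mappings and plain sequences are rebuilt as dict and list for simplicity. NumPy arrays and other types are not handled and raise TypeError.

```python
import collections.abc
from typing import Any, Callable, List

import torch

# Assumed scalar dtypes, mirroring PyTorch's default collate conventions.
_PRIMITIVE_DTYPES = {bool: torch.bool, int: torch.int64, float: torch.float64}


def empty_like_collated(sample: Any) -> Any:
    """Hypothetical helper: a zero-length batch shaped like the default-collated
    output for records structured like `sample`."""
    if isinstance(sample, torch.Tensor):
        # Tensors are stacked along a new leading batch dimension.
        return torch.zeros((0, *sample.shape), dtype=sample.dtype)
    if type(sample) in _PRIMITIVE_DTYPES:
        # Python scalars collate to a 1-D tensor.
        return torch.zeros((0,), dtype=_PRIMITIVE_DTYPES[type(sample)])
    if isinstance(sample, str):
        # Strings are passed through as a plain list (checked before Sequence).
        return []
    if isinstance(sample, collections.abc.Mapping):
        # Recurse per key, like the Mapping branch of collate.
        return {key: empty_like_collated(value) for key, value in sample.items()}
    if isinstance(sample, tuple) and hasattr(sample, "_fields"):  # namedtuple
        return type(sample)(*(empty_like_collated(value) for value in sample))
    if isinstance(sample, collections.abc.Sequence):
        # Plain tuples and lists collate to a list of per-position batches.
        return [empty_like_collated(value) for value in sample]
    raise TypeError(f"cannot build an empty batch for {type(sample).__name__}")


def wrap_collate_with_empty_any(
    collate_fn: Callable[[List[Any]], Any], sample: Any
) -> Callable[[List[Any]], Any]:
    """Hypothetical wrapper: delegate non-empty batches, synthesize empty ones."""

    def collate(batch: List[Any]) -> Any:
        if len(batch) > 0:
            return collate_fn(batch)
        return empty_like_collated(sample)

    return collate
```

Rebuilding the empty batch from the record on every call, rather than caching one object, avoids handing the same containers to different training steps. The sketch trusts the shapes of a single record, so datasets whose tensor shapes vary across records would need a different strategy.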

Relevant code pointer: https://github.com/pytorch/opacus/blob/7393ae47fdf824ad65d5035461dc391c0f4cc932/opacus/data_loader.py#L31