SaashaJoshi / piQture

piQture: A quantum machine learning library for image processing.
https://saashajoshi.github.io/piQture/
Apache License 2.0

Add `desired_labels` argument to `collate_fn` in `data_loader` #87

Closed: SaashaJoshi closed this issue 2 months ago

SaashaJoshi commented 2 months ago

We need to let users specify which labels they want to retrieve from the loaded data. To support this, collate_fn must accept a desired_labels argument. For example:

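from torch.utils.data import default_collate

def collate_fn(batch, desired_labels):
    # Keep only the (img, label) pairs whose label was requested.
    filtered_batch = []
    for img, label in batch:
        if label in desired_labels:
            filtered_batch.append((img, label))
    # Stack the surviving samples into batched tensors.
    return default_collate(filtered_batch)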
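One practical wrinkle, sketched here under stated assumptions: DataLoader calls collate_fn with a single argument (the list of samples in the batch), so desired_labels has to be bound ahead of time, for instance with functools.partial. The helper name filter_and_collate and its collate parameter are hypothetical, not part of piQture, and plain strings stand in for image tensors so the sketch runs without PyTorch.

```python
from functools import partial


def filter_and_collate(batch, desired_labels, collate=list):
    """Hypothetical helper: keep only (img, label) pairs whose label is in
    desired_labels, then pass the survivors to `collate`
    (e.g. torch.utils.data.default_collate in real use)."""
    kept = [(img, label) for img, label in batch if label in desired_labels]
    return collate(kept)


# DataLoader invokes collate_fn(batch) with one argument, so the extra
# parameters are bound beforehand with functools.partial.
collate_fn = partial(filter_and_collate, desired_labels={0, 1})

# Stand-in batch of (image, label) samples; strings replace image tensors.
batch = [("img_a", 0), ("img_b", 7), ("img_c", 1), ("img_d", 3)]
print(collate_fn(batch))  # [('img_a', 0), ('img_c', 1)]
```

With PyTorch installed, the same helper could be wired in as DataLoader(dataset, batch_size=32, collate_fn=partial(filter_and_collate, desired_labels={0, 1}, collate=default_collate)). Because filtering happens after batching, batches can come out smaller than batch_size, and a batch with no matching labels yields an empty list that default_collate cannot collate; filtering the dataset up front (for example with torch.utils.data.Subset) avoids both issues. If labels arrive as tensors rather than Python ints, they may need converting (e.g. label.item()) before the membership test.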