We require users to specify which labels they want to retrieve from the loaded data. To support this, collate_fn needs to accept a desired_labels argument. For example:
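import torch.utils.data

def collate_fn(batch, desired_labels):
    # Keep only the samples whose label is in desired_labels.
    new_batch = []
    for img, label in batch:
        if label in desired_labels:
            new_batch.append((img, label))
    # Stack the remaining samples into batched tensors.
    return torch.utils.data.default_collate(new_batch)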
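DataLoader calls collate_fn with a single argument, the batch, so the extra label argument has to be bound beforehand. The following is a minimal sketch of one way to do that with functools.partial. The name filter_collate, the sample strings, and the label values are illustrative assumptions, and the stand-in returns a plain list instead of calling default_collate so that it runs without PyTorch installed.

```python
from functools import partial

# Hypothetical stand-in for the collate function: it performs only the
# label filtering and returns a plain list (no default_collate call).
def filter_collate(batch, desired_labels):
    return [(img, label) for img, label in batch if label in desired_labels]

# Bind the extra argument so the result takes a single `batch` argument,
# which is the call shape DataLoader uses for collate_fn.
keep_0_and_1 = partial(filter_collate, desired_labels={0, 1})

# Illustrative batch of (image, label) pairs; strings stand in for images.
batch = [("img_a", 0), ("img_b", 2), ("img_c", 1)]
filtered = keep_0_and_1(batch)

# With PyTorch, the same binding would be passed to the loader, e.g.:
#   DataLoader(dataset, batch_size=4,
#              collate_fn=partial(collate_fn, desired_labels={0, 1}))
```

One design point to keep in mind: if no sample in a batch matches, the filtered list is empty, which default_collate is not expected to handle, so callers may want to guard against that case.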