Example usage
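```python
import torch
from torch.utils.data import DataLoader
from infobatch import InfoBatch  # import path assumed; adjust to your install

dataset = InfoBatch(your_original_training_dataset)  # wrap your existing training set
data_loader = DataLoader(dataset, sampler=dataset.sampler)  # use InfoBatch's sampler
loss = torch.mean(dataset.update(loss))  # loss: per-sample losses, shape (batch_size,)
```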
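Note: `loss` must be the unreduced loss, i.e. each sample must have its own loss value, so `loss` should have shape `(batch_size,)`. In PyTorch this means computing it with `reduction='none'` (for example `nn.CrossEntropyLoss(reduction='none')`) rather than the default `'mean'`.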
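To illustrate the note on loss shape, here is a minimal PyTorch sketch that contrasts the default reduced loss with the per-sample loss that `dataset.update` expects. It assumes `nn.CrossEntropyLoss` as the criterion and uses random logits in place of a real model; it does not call InfoBatch itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, num_classes = 8, 5
logits = torch.randn(batch_size, num_classes)             # stand-in for model outputs
targets = torch.randint(0, num_classes, (batch_size,))    # stand-in for labels

# Default reduction='mean' collapses the batch into one scalar:
# this is NOT the shape InfoBatch's update() expects.
mean_loss = nn.CrossEntropyLoss()(logits, targets)

# reduction='none' keeps one loss value per sample, shape (batch_size,).
per_sample_loss = nn.CrossEntropyLoss(reduction="none")(logits, targets)

print(mean_loss.shape)        # torch.Size([])
print(per_sample_loss.shape)  # torch.Size([8])

# Without class weights, averaging the per-sample vector gives back the usual scalar loss.
print(torch.allclose(per_sample_loss.mean(), mean_loss))  # True
```

In a real training loop, `per_sample_loss` is the tensor you would pass to `dataset.update(...)` before taking the mean and calling `backward()`.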