NUS-HPC-AI-Lab / InfoBatch

Lossless Training Speed Up by Unbiased Dynamic Data Pruning
318 stars, 18 forks

A new version with only 3 lines of change comes! #12

Status: Closed (tiandunx closed this issue 10 months ago)

tiandunx commented 10 months ago

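Example usage:

    import torch
    from torch.utils.data import DataLoader

    dataset = InfoBatch(your_original_training_dataset)          # line 1: wrap the original training dataset
    data_loader = DataLoader(dataset, sampler=dataset.sampler)   # line 2: use the sampler InfoBatch provides
    loss = torch.mean(dataset.update(loss))                      # line 3: pass per-sample losses to InfoBatch, then average

Note: `loss` must be the unreduced loss (for example, from a criterion built with `reduction='none'`), i.e. each sample must have its own loss value, so `loss` has shape `(batch_size,)`.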
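To make the note about unreduced losses concrete, here is a framework-free sketch. It assumes that the "unbiased dynamic data pruning" in the repository title works by randomly dropping part of the low-loss samples and scaling the surviving ones by `1 / (1 - prune_ratio)`, so their expected total contribution is unchanged. The function `weighted_per_sample_losses` and its `upweight` and `prune_ratio` parameters are hypothetical stand-ins, not InfoBatch's actual `update` implementation. The point it illustrates is that such reweighting needs one loss value per sample; a scalar (already reduced) loss cannot be paired with per-sample flags.

```python
from typing import List, Sequence


def weighted_per_sample_losses(
    losses: Sequence[float],
    upweight: Sequence[bool],
    prune_ratio: float,
) -> List[float]:
    """Hypothetical stand-in for dataset.update(loss); not InfoBatch's code.

    losses:      one unreduced loss value per sample (length == batch_size).
    upweight:    True for samples that survived random thinning of the
                 low-loss pool; their losses are scaled by 1 / (1 - prune_ratio).
    prune_ratio: assumed fraction of the low-loss pool that was dropped.
    This function only rescales; it does not perform the pruning itself.
    """
    if len(losses) != len(upweight):
        # A reduced (scalar) loss would fail here: reweighting is per sample.
        raise ValueError("need exactly one loss value per sample")
    if not 0.0 <= prune_ratio < 1.0:
        raise ValueError("prune_ratio must be in [0, 1)")
    scale = 1.0 / (1.0 - prune_ratio)
    return [l * scale if w else l for l, w in zip(losses, upweight)]


def mean(values: Sequence[float]) -> float:
    # Plays the role of torch.mean in line 3 of the example.
    return sum(values) / len(values)


# Unreduced losses for a batch of 4 samples, i.e. shape (batch_size,).
batch_losses = [0.25, 0.5, 1.5, 2.0]
# Assume the first two samples come from the thinned low-loss pool.
flags = [True, True, False, False]

weighted = weighted_per_sample_losses(batch_losses, flags, prune_ratio=0.5)
loss = mean(weighted)
print(weighted, loss)
```

With `prune_ratio=0.5`, each surviving low-loss sample counts twice, which compensates in expectation for the half of that pool that was skipped.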