dvgodoy / PyTorchStepByStep

Official repository of my book: "Deep Learning with PyTorch Step-by-Step: A Beginner's Guide"
https://pytorchstepbystep.com
MIT License
834 stars 310 forks

TransformedTensorDataset needs a seed #33

Closed nesaboz closed 1 year ago

nesaboz commented 1 year ago

After seeing some inconsistencies with TransformedTensorDataset, I finally found a solution, which is to seed the random number generators before applying the transform:

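import random

import torch
from torch.utils.data import Dataset


class TransformedTensorDataset(Dataset):
    def __init__(self, x, y, transform=None, seed=42):
        self.x = x
        self.y = y
        self.transform = transform
        self.seed = seed

    def __getitem__(self, index):
        x = self.x[index]

        # Reset both global RNGs so the transform draws the same
        # random values on every call, for every index.
        random.seed(self.seed)
        torch.manual_seed(self.seed)

        if self.transform is not None:
            x = self.transform(x)

        return x, self.y[index]

    def __len__(self):
        return len(self.x)
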
nesaboz commented 1 year ago

On further reflection, this consistency may not be desirable when randomness is expected. Adding seeds is fine for reproducibility during experimentation, but it should probably not be applied to augmentations in production.
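
One possible middle ground, as a rough sketch only: derive the seed from a base seed, an epoch counter, and the sample index, and apply it inside `torch.random.fork_rng` so the global generator is restored afterwards. The class name `SeededTransformDataset`, the `set_epoch` method, the seed-mixing formula, and the `add_noise` stand-in transform are all hypothetical and not part of the book's code. The sketch also assumes integer indices and a transform that draws from PyTorch's global CPU generator; Python's `random` module is not covered.

```python
import torch
from torch.utils.data import Dataset


class SeededTransformDataset(Dataset):
    """Hypothetical variant: reproducible, but not identical, augmentations."""

    def __init__(self, x, y, transform=None, seed=42):
        self.x = x
        self.y = y
        self.transform = transform
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # Call once per epoch so each epoch sees different augmentations.
        self.epoch = epoch

    def __getitem__(self, index):
        x = self.x[index]
        if self.transform is not None:
            # Assumed mixing scheme: a distinct seed per (epoch, index) pair.
            item_seed = self.seed + self.epoch * len(self.x) + index
            # fork_rng restores the global CPU RNG state on exit.
            with torch.random.fork_rng(devices=[]):
                torch.manual_seed(item_seed)
                x = self.transform(x)
        return x, self.y[index]

    def __len__(self):
        return len(self.x)


def add_noise(t):
    # Stand-in for a random augmentation that uses the global torch RNG.
    return t + torch.randn_like(t)


ds = SeededTransformDataset(torch.zeros(4, 3), torch.arange(4), transform=add_noise)
first = ds[0][0]
again = ds[0][0]   # same epoch, same index: same augmentation
other = ds[1][0]   # different index: different augmentation
ds.set_epoch(1)
later = ds[0][0]   # same index, next epoch: different augmentation
```

With this scheme a run can be repeated exactly from the same base seed, while samples and epochs still receive different random draws. Whether that trade-off is appropriate depends on the experiment.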