lukasruff / Deep-SVDD-PyTorch

A PyTorch implementation of the Deep SVDD anomaly detection method
MIT License

PyTorch transforms.Normalize() usage #13

Closed kumarneelabh13 closed 5 years ago

kumarneelabh13 commented 5 years ago

According to PyTorch docs, the Normalize transform expects the mean and std for every channel.

CLASS torchvision.transforms.Normalize(mean, std, inplace=False)

But currently, this implementation of Deep SVDD passes the "min" value in place of "mean" and the "max - min" value in place of "std". Moreover, it uses a single value repeated across all three channels, even in the case of CIFAR-10.

From datasets/cifar10.py, line 35:

> transforms.Normalize([min_max[normal_class][0]] * 3,
>                      [min_max[normal_class][1] - min_max[normal_class][0]] * 3)])

Is this intentional or a real issue?
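
For reference, `transforms.Normalize` applies `(x - mean) / std` per channel, which is what makes the choice of arguments above notable. A minimal sketch of that per-channel behavior (the mean/std values below are hypothetical, just for illustration):

```python
import torch
from torchvision import transforms

# Hypothetical per-channel statistics, only to illustrate the behavior.
mean = [0.4914, 0.4822, 0.4465]
std = [0.2470, 0.2435, 0.2616]
normalize = transforms.Normalize(mean, std)

x = torch.rand(3, 32, 32)  # a fake CIFAR-10-sized image tensor
y = normalize(x)

# Equivalent manual computation, channel by channel:
manual = (x - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)
assert torch.allclose(y, manual)
```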

kumarneelabh13 commented 5 years ago

I got it. You've used the Normalize() function to implement min-max scaling.
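
In other words, with `mean = min` and `std = max - min`, `Normalize` computes `(x - min) / (max - min)`. A small sketch of that equivalence, using made-up min/max values in place of the repo's `min_max` table:

```python
import torch
from torchvision import transforms

# Hypothetical min/max for one normal class (stand-ins for min_max[normal_class]).
data_min, data_max = -0.8, 9.2

# Passing min as "mean" and (max - min) as "std" turns Normalize into
# min-max scaling: (x - min) / (max - min).
minmax_scale = transforms.Normalize([data_min] * 3, [data_max - data_min] * 3)

x = torch.empty(3, 32, 32).uniform_(data_min, data_max)
y = minmax_scale(x)
print(y.min().item(), y.max().item())  # both lie in [0, 1] for inputs within [min, max]
```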