Add a transform that normalizes each sample to zero mean and unit standard deviation.
That is, reintroduce the following function into our current `transform` module:
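```python
import torch

def preprocess_images(images, patch_size_x, patch_size_y):
    # Work on a copy so the caller's tensor is left untouched.
    data = images.detach().clone()
    # Subtract each sample's mean over (C, H, W).
    means = torch.mean(data, dim=(1, 2, 3), keepdim=True)
    data = data - means
    # Divide by 10x each sample's std, so the output std is 0.1 rather than 1.
    stds = 10 * torch.std(data, dim=(1, 2, 3), keepdim=True)
    data = data / stds
    # Flatten into rows of patch_size_x * patch_size_y values.
    data = data.reshape(-1, patch_size_x * patch_size_y)
    return data
```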
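Since the request is to put this back into the `transform` module, one possible shape is a callable class. This is a minimal sketch, assuming the module's transforms are plain callables applied to a batched `(N, C, H, W)` tensor; the class name `PerSampleStandardize` and its `scale` and `eps` parameters are hypothetical. `scale=1.0` gives the unit standard deviation described above, and `scale=10.0` reproduces the factor of 10 in the function. `eps` is an added guard (not in the original) so that constant samples do not divide by zero.

```python
import torch


class PerSampleStandardize:
    """Hypothetical transform wrapping the per-sample normalization.

    Each sample of an (N, C, H, W) batch is shifted to zero mean and scaled
    to a standard deviation of 1 / scale, then flattened into patch rows.
    """

    def __init__(self, patch_size_x, patch_size_y, scale=1.0, eps=1e-8):
        self.patch_size_x = patch_size_x
        self.patch_size_y = patch_size_y
        self.scale = scale  # scale=10.0 matches the original function
        self.eps = eps      # assumed floor on the std for constant samples

    def __call__(self, images):
        # Copy so the input tensor is not modified.
        data = images.detach().clone()
        # Per-sample mean over (C, H, W).
        data = data - data.mean(dim=(1, 2, 3), keepdim=True)
        # Per-sample std, floored at eps, then multiplied by scale.
        stds = data.std(dim=(1, 2, 3), keepdim=True).clamp(min=self.eps) * self.scale
        data = data / stds
        return data.reshape(-1, self.patch_size_x * self.patch_size_y)
```

For example, `PerSampleStandardize(8, 8)(x)` on a `(4, 1, 8, 8)` batch returns a `(4, 64)` tensor whose rows each have mean 0 and standard deviation 1.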