hendrycks / ss-ood

Self-Supervised Learning for OOD Detection (NeurIPS 2019)
MIT License

Question about loss function (train.py, line 168) #5

waldstein94 closed this issue 5 years ago

waldstein94 commented 5 years ago

First, thank you for releasing your code. It's very helpful for my research :)

I wonder whether the objective function in train.py (line 168) covers only the rotation and translation classes. In your paper, the highest score on ImageNet came from a model trained with RotNet + Translation + Self-attention + Resize.

I hope you can answer my question soon!
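
The question is about which self-supervised targets the loss covers. As a rough illustration only, an objective limited to rotation and translation can be written as a sum of cross-entropy terms, one per classification head. The NumPy sketch below assumes hypothetical heads named `rotation`, `translate_x`, and `translate_y` with 4, 3, and 3 classes. These names, class counts, and helper functions are assumptions for illustration and are not taken from train.py.

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean softmax cross-entropy; logits is (N, K), targets is (N,) of class ids."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

def self_supervised_loss(head_logits, head_targets):
    """Hypothetical helper (not from the repo): sum of per-head cross-entropies."""
    return sum(cross_entropy(head_logits[name], head_targets[name]) for name in head_logits)

# Hypothetical heads: 4 rotation classes, 3 horizontal and 3 vertical translation classes
rng = np.random.default_rng(0)
n = 8
num_classes = {"rotation": 4, "translate_x": 3, "translate_y": 3}
logits = {k: rng.normal(size=(n, c)) for k, c in num_classes.items()}
targets = {k: rng.integers(0, c, size=n) for k, c in num_classes.items()}
print(self_supervised_loss(logits, targets))
```

Under this framing, adding another target such as a resize/zoom class would just mean one more entry in both dictionaries; architectural changes like self-attention would not appear in the loss at all.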

hendrycks commented 5 years ago

Self-attention is in the architecture and not the loss. We did not include the resize code since it only slightly improved performance.

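```python
import numpy as np
import torch
import torchvision.transforms as trn
import torchvision.transforms.functional as trnF

# Plain resize followed by a random 224x224 crop
resize_and_crop = trn.Compose([trn.Resize(256), trn.RandomCrop(224)])
# Same crop, then upscale to 300 and center-crop back to 224 (about a 300/224 = 1.34x zoom)
resize_and_crop_and_zoom = trn.Compose([trn.Resize(256), trn.RandomCrop(224), trn.Resize(300), trn.CenterCrop(224)])

# ...

    def __getitem__(self, index):
        # Each base image appears num_perts times; the remainder selects the perturbation config
        x, _ = self.dataset[index // num_perts]
        pert = pert_configs[index % num_perts]

        # pert[0] selects whether the zoomed crop is used
        if pert[0] == 0:
            x = np.asarray(resize_and_crop(x))
        else:
            x = np.asarray(resize_and_crop_and_zoom(x))

        # Random horizontal flip (for an RGB image x is H x W x C, so axis 1 is the width)
        if np.random.uniform() < 0.5:
            x = x[:, ::-1]

        # Class index for every parameter except the last; the last column is the rotation label (0-3)
        label = [expanded_params[i].index(pert[i]) for i in range(len(expanded_params) - 1)]
        label = np.vstack((label + [0], label + [1], label + [2], label + [3]))
        # .copy() removes the negative stride left by the flip; result is 1 x C x H x W
        x = trnF.to_tensor(x.copy()).unsqueeze(0).numpy()
        # Stack the 0, 90, 180 and 270 degree rotations along the batch axis: 4 x C x H x W
        x = np.concatenate((x, np.rot90(x, 1, axes=(2, 3)), np.rot90(x, 2, axes=(2, 3)), np.rot90(x, 3, axes=(2, 3))), 0)
        return torch.FloatTensor(x), label
```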
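
To see what `__getitem__` returns, the following NumPy-only sketch reproduces its label construction and rotation stacking on a random image. The values of `expanded_params` and `pert` are hypothetical stand-ins for the elided definitions (a two-way resize/zoom flag, two translation grids, and a final rotation entry); they are assumptions, not the repo's actual configuration. The torchvision cropping is skipped, and `to_tensor` is replaced by a manual HWC-to-CHW conversion with scaling to [0, 1].

```python
import numpy as np

# Hypothetical parameter grids (NOT the repo's values): zoom flag, x-shift, y-shift, rotation
expanded_params = [[0, 1], [-1, 0, 1], [-1, 0, 1], [0, 1, 2, 3]]
pert = (1, -1, 0)  # hypothetical config: zoomed crop, shifted left, no vertical shift

def build_labels(pert, expanded_params):
    # Class index of each parameter except the last, as in __getitem__
    label = [expanded_params[i].index(pert[i]) for i in range(len(expanded_params) - 1)]
    # One row per rotation; the last column is the rotation class (0-3)
    return np.vstack((label + [0], label + [1], label + [2], label + [3]))

def stack_rotations(img_hwc):
    # HWC uint8 -> 1 x C x H x W float32 in [0, 1], mimicking to_tensor(...).unsqueeze(0)
    x = (img_hwc.astype(np.float32) / 255.0).transpose(2, 0, 1)[None]
    # Rotate in the spatial plane (axes 2 and 3) and stack along the batch axis
    return np.concatenate([np.rot90(x, k, axes=(2, 3)) for k in range(4)], 0)

labels = build_labels(pert, expanded_params)
images = stack_rotations(np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8))
print(labels)        # rows [1 0 1 0], [1 0 1 1], [1 0 1 2], [1 0 1 3]
print(images.shape)  # (4, 3, 224, 224)
```

With these stand-ins, all four rows share the non-rotation indices `[1, 0, 1]` and differ only in the rotation column. Because `pert[0]` is encoded in every row, the resize/zoom choice is trained as one more classification target alongside the other parameters.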