Cysu / open-reid

Open source person re-identification library in python
https://cysu.github.io/open-reid/
MIT License

OIM loss #90

Status: Open · khadijakhaldi opened this issue 4 years ago

khadijakhaldi commented 4 years ago

Why does the OIM loss need to define an OIM class that extends the autograd Function class? Normally PyTorch computes the backward pass for us, so we shouldn't need to write the backward function ourselves. Thank you.
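
For example, a plain differentiable op such as a matrix product already gets its backward derived automatically (a minimal standalone sketch, not taken from open-reid):

import torch

lut = torch.eye(3)                         # a fixed lookup table
x = torch.randn(2, 3, requires_grad=True)

out = x.mm(lut.t())                        # plain op, recorded by autograd
out.sum().backward()                       # backward derived automatically
print(x.grad)                              # equals torch.ones(2, 3).mm(lut)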

X-funbean commented 4 years ago

I believe you are right. Also, I find that the momentum seems useless: I get the same output no matter what value I set it to. I wrote a test based on test/loss/test_oim.py. Note that I'm using the latest version of PyTorch, so I made several modifications.

import torch
import torch.nn.functional as F
from torch import nn, autograd

class OIM(autograd.Function):

    # def __init__(self, lut, momentum=0.5):
    #     super(OIM, self).__init__()
    #     self.lut = lut
    #     self.momentum = momentum
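    # NOTE: the stateful style above no longer works in recent PyTorch;
    # autograd.Function now requires staticmethod forward/backward, with any
    # state (lut, momentum) passed through forward's arguments instead.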

    @staticmethod
    def forward(ctx, inputs, targets, lut, momentum=0.5):
        ctx.save_for_backward(inputs, targets)
        ctx.lut = lut
        ctx.momentum = momentum

        # logits: inner products between the input features and all LUT entries
        outputs = inputs.mm(lut.t())
        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
        inputs, targets = ctx.saved_tensors
        grad_inputs = None
        if ctx.needs_input_grad[0]:
            # debug print: shows which of forward's arguments require grad
            print(ctx.needs_input_grad)
            # chain rule through the matmul: d(loss)/d(inputs) = grad_outputs @ lut
            grad_inputs = grad_outputs.mm(ctx.lut)

        # side effect: update each seen class's LUT entry as a momentum-weighted
        # running average of its feature, then renormalize to unit length
        for x, y in zip(inputs, targets):
            ctx.lut[y] = ctx.momentum * ctx.lut[y] + (1. - ctx.momentum) * x
            ctx.lut[y] /= ctx.lut[y].norm()
        return grad_inputs, None, None, None

class OIMLoss(nn.Module):
    def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5,
                 weight=None, reduction='mean'):
        super(OIMLoss, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.momentum = momentum
        self.scalar = scalar
        self.weight = weight
        self.reduction = reduction

        # the LUT buffer: one feature vector per class, updated inside OIM.backward
        self.register_buffer('lut', torch.zeros(num_classes, num_features))

        self.oim = OIM.apply

    def forward(self, inputs, targets):
        # inputs = oim(inputs, targets, self.lut, momentum=self.momentum)
        inputs = self.oim(inputs, targets, self.lut, self.momentum)
        inputs *= self.scalar
        loss = F.cross_entropy(inputs, targets, weight=self.weight, reduction=self.reduction)
        return loss
        # return loss, inputs

class Test(nn.Module):
    def __init__(self, num_features, num_classes, scalar=1.0, weight=None, reduction='mean'):
        super(Test, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.scalar = scalar
        self.weight = weight
        self.reduction = reduction

        # same LUT buffer, but used as a plain constant (never updated)
        self.register_buffer('lut', torch.zeros(num_classes, num_features))

    def forward(self, inputs, targets):
        inputs = inputs.mm(self.lut.t())
        inputs *= self.scalar
        loss = F.cross_entropy(inputs, targets, weight=self.weight, reduction=self.reduction)
        return loss

if __name__ == '__main__':
    criterion = OIMLoss(3, 3, scalar=1.0, reduction='sum', momentum=0.9)
    # criterion_2 = OIMLoss(3, 3, scalar=1.0, reduction='sum', momentum=0)
    criterion_2 = Test(3, 3, scalar=1.0, reduction='sum')

    seed = 2018
    criterion.lut = torch.eye(3)
    criterion_2.lut = torch.eye(3)
    torch.manual_seed(seed)

    x = torch.randn(3, 3, requires_grad=True)
    y = torch.arange(0, 3)

    loss = criterion(x, y)
    loss.backward()
    probs = F.softmax(x, dim=-1)
    grads = probs.data - torch.eye(3)
    abs_diff = torch.abs(grads - x.grad.data)

    print(probs)
    print(grads)
    print(abs_diff)

    print('*' * 50)

    torch.manual_seed(seed)
    x = torch.randn(3, 3, requires_grad=True)
    y = torch.arange(0, 3)
    loss = criterion_2(x, y)
    loss.backward()
    probs = F.softmax(x, dim=-1)
    grads = probs.data - torch.eye(3)
    abs_diff = torch.abs(grads - x.grad.data)

    print(probs)
    print(grads)
    print(abs_diff)

and the output is:

(True, False, False, False)
tensor([[0.6779, 0.2672, 0.0548],
        [0.1680, 0.2574, 0.5747],
        [0.1614, 0.3012, 0.5374]], grad_fn=<SoftmaxBackward>)
tensor([[-0.3221,  0.2672,  0.0548],
        [ 0.1680, -0.7426,  0.5747],
        [ 0.1614,  0.3012, -0.4626]])
tensor([[0.0000e+00, 0.0000e+00, 3.7253e-09],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.4901e-08, 0.0000e+00, 5.9605e-08]])
**************************************************
tensor([[0.6779, 0.2672, 0.0548],
        [0.1680, 0.2574, 0.5747],
        [0.1614, 0.3012, 0.5374]], grad_fn=<SoftmaxBackward>)
tensor([[-0.3221,  0.2672,  0.0548],
        [ 0.1680, -0.7426,  0.5747],
        [ 0.1614,  0.3012, -0.4626]])
tensor([[0.0000e+00, 0.0000e+00, 3.7253e-09],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.4901e-08, 0.0000e+00, 5.9605e-08]])

I think this version of the OIM loss needs some polishing, since the code seems to be widely used in person search. Thanks!
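
For completeness, here is a sketch of how the same loss could be written without a custom backward at all, by treating the LUT update as an explicit no-grad side effect (a hypothetical rewrite for illustration; OIMLossAlt is not the library's code, and the reference implementation performs the update in backward instead):

import torch
import torch.nn.functional as F
from torch import nn

class OIMLossAlt(nn.Module):
    def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5):
        super(OIMLossAlt, self).__init__()
        self.scalar = scalar
        self.momentum = momentum
        self.register_buffer('lut', torch.zeros(num_classes, num_features))

    def forward(self, inputs, targets):
        # clone the LUT so the in-place update below cannot invalidate the
        # tensor autograd saves for the backward pass of mm()
        outputs = inputs.mm(self.lut.t().clone()) * self.scalar
        loss = F.cross_entropy(outputs, targets)
        with torch.no_grad():  # keep the LUT update out of the autograd graph
            for x, y in zip(inputs, targets):
                self.lut[y] = self.momentum * self.lut[y] + (1. - self.momentum) * x
                self.lut[y] /= self.lut[y].norm()
        return loss

With this version the update runs on every forward call rather than only when backward is invoked, which amounts to the same thing in a normal training step; autograd then derives the gradient of the matrix product on its own.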