Open khadijakhaldi opened 4 years ago
I believe you are right. I also find that 'momentum' seems to have no effect: I get the same output no matter what value I set it to.
I wrote a test script based on test/loss/test_oim.py. Note that I'm using the latest version of PyTorch, so I made several modifications.
from __future__ import absolute_import
import torch
import torch.nn.functional as F
from torch import nn, autograd
from torch.autograd import Variable
class OIM(autograd.Function):
    """OIM operator: score features against an external lookup table (LUT).

    Forward computes ``inputs @ lut.t()``.  Backward returns the gradient
    w.r.t. ``inputs`` and, as a side effect, updates ``lut`` in place with a
    momentum-smoothed, L2-normalized copy of each input feature.

    NOTE(review): because the LUT update happens only in ``backward`` (after
    ``grad_inputs`` is computed), ``momentum`` never affects the loss or
    gradient of the *current* step -- only subsequent forward passes.  This is
    why a single-step test observes identical output for any momentum value.
    """

    @staticmethod
    def forward(ctx, inputs, targets, lut, momentum=0.5):
        ctx.save_for_backward(inputs, targets)
        # The LUT is an external buffer, not a differentiable input; keep a
        # plain reference so backward can mutate it in place.
        ctx.lut = lut
        ctx.momentum = momentum
        return inputs.mm(lut.t())

    @staticmethod
    def backward(ctx, grad_outputs):
        inputs, targets = ctx.saved_tensors
        grad_inputs = None
        if ctx.needs_input_grad[0]:
            # Gradient w.r.t. inputs, using the LUT *before* the update below.
            grad_inputs = grad_outputs.mm(ctx.lut)
        # Momentum update of the LUT rows addressed by the targets, followed
        # by re-normalization to unit length.
        for x, y in zip(inputs, targets):
            ctx.lut[y] = ctx.momentum * ctx.lut[y] + (1. - ctx.momentum) * x
            ctx.lut[y] /= ctx.lut[y].norm()
        # No gradients for targets, lut, or momentum.
        return grad_inputs, None, None, None
class OIMLoss(nn.Module):
    """Cross-entropy loss over OIM scores against a per-class lookup table.

    The ``lut`` buffer holds one prototype row per class and is updated in
    place by ``OIM``'s backward pass using ``momentum``.
    """

    def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5,
                 weight=None, reduction='mean'):
        super(OIMLoss, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.momentum = momentum
        self.scalar = scalar
        self.weight = weight
        self.reduction = reduction
        # (num_classes, num_features) prototype table, starts at zero.
        self.register_buffer('lut', torch.zeros(num_classes, num_features))
        self.oim = OIM.apply

    def forward(self, inputs, targets):
        """Score inputs against the LUT, scale the logits, return CE loss."""
        logits = self.oim(inputs, targets, self.lut, self.momentum)
        logits = logits * self.scalar
        return F.cross_entropy(logits, targets,
                               weight=self.weight, reduction=self.reduction)
class Test(nn.Module):
    """Baseline module: plain LUT-matmul + cross-entropy, no OIM backward.

    Identical to ``OIMLoss``'s forward math, but the LUT is never updated,
    which lets the analytic softmax gradient be checked directly.
    """

    def __init__(self, num_features, num_classes, scalar=1.0, weight=None, reduction='mean'):
        super(Test, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.scalar = scalar
        self.weight = weight
        self.reduction = reduction
        # Static prototype table; this module never modifies it.
        self.register_buffer('lut', torch.zeros(num_classes, num_features))

    def forward(self, inputs, targets):
        """Return the cross-entropy loss of the scaled LUT scores."""
        logits = self.scalar * inputs.mm(self.lut.t())
        return F.cross_entropy(logits, targets,
                               weight=self.weight, reduction=self.reduction)
if __name__ == '__main__':
    SEED = 2018

    def compare_against_analytic(criterion):
        """Run one forward/backward pass from a fixed seed and print the
        softmax probabilities, the analytic CE gradient (softmax - onehot),
        and the absolute difference against autograd's gradient."""
        torch.manual_seed(SEED)
        feats = torch.randn(3, 3, requires_grad=True)
        labels = torch.arange(0, 3)
        criterion(feats, labels).backward()
        soft = F.softmax(feats, dim=-1)
        analytic = soft.data - torch.eye(3)
        print(soft)
        print(analytic)
        print(torch.abs(analytic - feats.grad.data))

    oim_criterion = OIMLoss(3, 3, scalar=1.0, reduction='sum', momentum=0.9)
    plain_criterion = Test(3, 3, scalar=1.0, reduction='sum')
    # Start both lookup tables from the identity so logits equal the inputs.
    oim_criterion.lut = torch.eye(3)
    plain_criterion.lut = torch.eye(3)

    compare_against_analytic(oim_criterion)
    print('*' * 50)
    compare_against_analytic(plain_criterion)
and the output is
(True, False, False, False)
tensor([[0.6779, 0.2672, 0.0548],
[0.1680, 0.2574, 0.5747],
[0.1614, 0.3012, 0.5374]], grad_fn=<SoftmaxBackward>)
tensor([[-0.3221, 0.2672, 0.0548],
[ 0.1680, -0.7426, 0.5747],
[ 0.1614, 0.3012, -0.4626]])
tensor([[0.0000e+00, 0.0000e+00, 3.7253e-09],
[0.0000e+00, 0.0000e+00, 0.0000e+00],
[1.4901e-08, 0.0000e+00, 5.9605e-08]])
**************************************************
tensor([[0.6779, 0.2672, 0.0548],
[0.1680, 0.2574, 0.5747],
[0.1614, 0.3012, 0.5374]], grad_fn=<SoftmaxBackward>)
tensor([[-0.3221, 0.2672, 0.0548],
[ 0.1680, -0.7426, 0.5747],
[ 0.1614, 0.3012, -0.4626]])
tensor([[0.0000e+00, 0.0000e+00, 3.7253e-09],
[0.0000e+00, 0.0000e+00, 0.0000e+00],
[1.4901e-08, 0.0000e+00, 5.9605e-08]])
I think this version of the OIM loss needs to be fixed, because this code is widely used in person search. Thanks!
Why does the OIM loss need to define the class OIM, which extends the Function class? Normally PyTorch computes the backward pass for us, so we shouldn't need to write the backward function ourselves. Thank you.