torch / torch7

http://torch.ch

How to extend MSECriterion so that the loss is only computed where the ground-truth label is nonzero #1055

Open tempdelhi123 opened 7 years ago

tempdelhi123 commented 7 years ago

Hello, for my application the ground-truth label is available only sparsely and is zero everywhere else. I therefore want to extend MSECriterion so that the loss is computed only where the ground-truth label is nonzero. Any ideas on how to do this, or pointers to tutorials, would be very helpful.
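Written out, the requested loss could look like this. Here $x$ is the network output, $t$ the target, $m_i = 1$ where $t_i \neq 0$ and $0$ elsewhere, and $N = \sum_i m_i$ is the number of labelled entries. Averaging over the $N$ labelled entries (rather than over all entries) is an assumption that mirrors sizeAverage = true; when $N = 0$ the loss and gradient are taken to be 0.

$$
L(x, t) = \frac{1}{N} \sum_i m_i \,(x_i - t_i)^2,
\qquad
\frac{\partial L}{\partial x_i} = \frac{2}{N}\, m_i \,(x_i - t_i).
$$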

tastyminerals commented 7 years ago

Check the torch/extra/nn/MSECriterion.lua file. Calling forward() invokes MSECriterion:updateOutput(input, target), and calling backward() invokes MSECriterion:updateGradInput(input, target), so these are the two methods to modify. See my example variant below, which keeps two additional fields, self.prev_output and self.prev_gradInput, to store the previously computed MSE loss and gradient. If the target tensor is entirely zero, it returns those previous values; otherwise it computes the MSE the normal way. Adjust the example to your needs, save it locally as myproject/criterions/MSECriterion.lua, and load it with require 'myproject.criterions.MSECriterion'.

require 'nn'

-- registers a class named MSECriterion (distinct from nn.MSECriterion)
local MSECriterion, parent = torch.class('MSECriterion', 'nn.Criterion')

function MSECriterion:__init(sizeAverage)
  parent.__init(self)
  self.prev_output = 0
  self.prev_gradInput = 0
  if sizeAverage ~= nil then
    self.sizeAverage = sizeAverage
  else
    self.sizeAverage = true
  end
end

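function MSECriterion:updateOutput(input, target)
  -- compute the loss only if at least one target entry is nonzero
  -- (absolute values are summed so that mixed-sign labels cannot cancel out)
  if torch.abs(target):sum() ~= 0 then
    self.output_tensor = self.output_tensor or input.new(1)
    input.THNN.MSECriterion_updateOutput(
      input:cdata(),
      target:cdata(),
      self.output_tensor:cdata(),
      self.sizeAverage
    )
    self.output = self.output_tensor[1]
    self.prev_output = self.output
    return self.output
  else
    -- all-zero target: reuse the previously computed loss
    self.output = self.prev_output
    return self.output
  end
end

function MSECriterion:updateGradInput(input, target)
  if torch.abs(target):sum() ~= 0 then
    input.THNN.MSECriterion_updateGradInput(
      input:cdata(),
      target:cdata(),
      self.gradInput:cdata(),
      self.sizeAverage
    )
    -- note: this stores a reference to self.gradInput, not a copy
    self.prev_gradInput = self.gradInput
    return self.gradInput
  elseif torch.isTensor(self.prev_gradInput) then
    -- all-zero target: reuse the previously computed gradient
    return self.prev_gradInput
  else
    -- no gradient computed yet (prev_gradInput is still the number 0),
    -- so return a zero gradient shaped like the input instead
    self.gradInput:resizeAs(input):zero()
    return self.gradInput
  end
end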
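The variant above switches the loss on or off for the target tensor as a whole: whenever any entry is nonzero, the zero-labelled entries still contribute to the MSE. For the per-element behaviour asked for in the question, a minimal sketch could multiply the residual by a mask of the nonzero labels. It assumes Torch7 with the nn package; the class name nn.MaskedMSECriterion and the helper maskedResidual are hypothetical (not part of nn), and with sizeAverage = true it averages over the labelled entries only, which is a design choice rather than what nn.MSECriterion does.

```lua
require 'nn'

-- Hypothetical criterion (not part of nn): MSE restricted to entries whose target is nonzero.
local MaskedMSECriterion, parent = torch.class('nn.MaskedMSECriterion', 'nn.Criterion')

function MaskedMSECriterion:__init(sizeAverage)
  parent.__init(self)
  -- by default, average over the labelled (nonzero-target) entries
  if sizeAverage ~= nil then
    self.sizeAverage = sizeAverage
  else
    self.sizeAverage = true
  end
end

-- Hypothetical helper: builds a mask that is 1 where the target is nonzero and 0
-- elsewhere (in the input's tensor type), and returns the masked residual
-- mask * (input - target) together with the number of labelled entries.
function MaskedMSECriterion:maskedResidual(input, target)
  local mask = target:ne(0):typeAs(input)
  local diff = input:clone():add(-1, target):cmul(mask)
  return diff, mask:sum()
end

function MaskedMSECriterion:updateOutput(input, target)
  local diff, n = self:maskedResidual(input, target)
  -- sum of squared residuals over the labelled entries (0 if there are none)
  local loss = diff:dot(diff)
  if self.sizeAverage and n > 0 then
    loss = loss / n
  end
  self.output = loss
  return self.output
end

function MaskedMSECriterion:updateGradInput(input, target)
  local diff, n = self:maskedResidual(input, target)
  -- derivative of m_i * (x_i - t_i)^2 is 2 * m_i * (x_i - t_i); unlabelled entries get 0
  self.gradInput:resizeAs(input):copy(diff):mul(2)
  if self.sizeAverage and n > 0 then
    self.gradInput:div(n)
  end
  return self.gradInput
end

-- usage: only entries 2 and 4 carry labels
local crit = nn.MaskedMSECriterion()
local input = torch.Tensor({1, 2, 3, 4})
local target = torch.Tensor({0, 2, 0, 6})
print(crit:forward(input, target))   -- ((2 - 2)^2 + (4 - 6)^2) / 2 = 2
print(crit:backward(input, target))  -- gradient values 0, 0, 0, -2
```

Because the mask is rebuilt from the target in both methods, backward() does not rely on state left behind by forward(); the price is a few extra tensor allocations per call.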