Open tempdelhi123 opened 7 years ago
Check the torch/extra/nn/MSECriterion.lua file. When you call forward(), MSECriterion:updateOutput(input, target) is invoked; when you call backward(), MSECriterion:updateGradInput(input, target) is invoked. Modify these two methods accordingly.
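As a quick illustration of that dispatch, here is a minimal sketch using the stock nn.MSECriterion. It assumes Torch7 with the nn package installed; the tensor values are made up for the example.

```lua
require 'nn'

local crit = nn.MSECriterion()            -- sizeAverage defaults to true
local input = torch.Tensor{1, 2, 3}       -- illustrative prediction
local target = torch.Tensor{1, 2, 5}      -- illustrative ground truth

-- forward() dispatches to updateOutput() and returns the scalar loss.
local loss = crit:forward(input, target)

-- backward() dispatches to updateGradInput() and returns d(loss)/d(input).
local grad = crit:backward(input, target)

print(loss)
print(grad)
```

Overriding updateOutput and updateGradInput in a subclass is therefore enough to change what forward() and backward() compute.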
See my example variant below, which keeps additional self.prev_output and self.prev_gradInput fields to store the previous MSE output and gradient. If the target sums to 0, it returns the previously computed MSE error; otherwise it computes the MSE error the normal way. Adjust the example to your needs, save it locally in myproject/criterions, and load it with require 'myproject.criterions.MSECriterion'.
require 'nn'

local MSECriterion, parent = torch.class('MSECriterion', 'nn.Criterion')

function MSECriterion:__init(sizeAverage)
   parent.__init(self)
   -- Cache of the last loss and gradient computed for a non-zero target.
   -- Note: prev_gradInput stays the number 0 until the first non-zero target is seen.
   self.prev_output = 0
   self.prev_gradInput = 0
   if sizeAverage ~= nil then
      self.sizeAverage = sizeAverage
   else
      self.sizeAverage = true
   end
end

function MSECriterion:updateOutput(input, target)
   if torch.sum(target) ~= 0 then
      -- Target has labels: compute the MSE the normal way and cache it.
      self.output_tensor = self.output_tensor or input.new(1)
      input.THNN.MSECriterion_updateOutput(
         input:cdata(),
         target:cdata(),
         self.output_tensor:cdata(),
         self.sizeAverage
      )
      self.output = self.output_tensor[1]
      self.prev_output = self.output
      return self.output
   else
      -- Target sums to zero: return the previously computed loss.
      return self.prev_output
   end
end

function MSECriterion:updateGradInput(input, target)
   if torch.sum(target) ~= 0 then
      input.THNN.MSECriterion_updateGradInput(
         input:cdata(),
         target:cdata(),
         self.gradInput:cdata(),
         self.sizeAverage
      )
      -- This stores a reference to self.gradInput, not a copy.
      self.prev_gradInput = self.gradInput
      return self.gradInput
   else
      -- Target sums to zero: return the previously computed gradient.
      return self.prev_gradInput
   end
end
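The variant above reuses the previous loss whenever the whole target sums to zero. If the loss should instead ignore individual zero entries of the target, a per-element mask is one option. The following is only a sketch under stated assumptions: the class name nn.MaskedMSECriterion is hypothetical (it is not part of the nn package), a target value of exactly 0 is assumed to always mean "no label", and the loss is averaged over the labelled elements only.

```lua
require 'nn'

-- Hypothetical class name; it is not part of the nn package.
local MaskedMSECriterion, parent = torch.class('nn.MaskedMSECriterion', 'nn.Criterion')

function MaskedMSECriterion:__init()
   parent.__init(self)
end

-- Returns the masked difference and the number of labelled elements.
local function maskedDiff(input, target)
   local mask = target:ne(0):typeAs(input)   -- 1 where a label exists, 0 elsewhere
   local count = math.max(mask:sum(), 1)     -- avoid division by zero when nothing is labelled
   local diff = (input - target):cmul(mask)  -- zero out the unlabelled positions
   return diff, count
end

function MaskedMSECriterion:updateOutput(input, target)
   local diff, count = maskedDiff(input, target)
   -- Mean of squared errors over labelled elements only.
   self.output = diff:pow(2):sum() / count
   return self.output
end

function MaskedMSECriterion:updateGradInput(input, target)
   local diff, count = maskedDiff(input, target)
   -- Gradient is 2 * (input - target) / count where labelled, 0 elsewhere.
   self.gradInput:resizeAs(input):copy(diff):mul(2 / count)
   return self.gradInput
end
```

With this sketch, unlabelled positions contribute neither to the loss nor to the gradient, so the network output at those positions is left unconstrained. If 0 can also be a legitimate label value, a separate mask tensor would be needed instead of deriving the mask from the target.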
Hello, for my application the ground-truth label is only available sparsely and is zero elsewhere. Therefore, I want to extend MSECriterion so that the loss is only computed where the ground-truth label is nonzero. Any ideas on how to do this, or pointers to tutorials, would be very helpful.