ClementPinard / FlowNetTorch

Torch implementation of Fischer et al. FlowNet training code

Computing EPEs on flow fields with invalid values #16

Closed: SteveSZF closed this issue 5 years ago

SteveSZF commented 6 years ago

On some optical flow datasets, the flow fields contain many invalid flow values. When we train the network on these datasets, do we need to change the way EPEs are computed so that the invalid values are excluded during forward and backward propagation? It seems that your code does not exclude invalid flow values when computing the EPE.

ClementPinard commented 6 years ago

What are the invalid values like? Are they NaNs? I am not very familiar with Lua anymore (it's been a year!), but you can use a ByteTensor mask to select the values over which you compute the mean. Explanation here. Note that x:maskedSelect(mask) is the same as x[mask], provided mask is the right type (a ByteTensor). Try this code for the EPE forward function (although I'm not 100% sure it works; Torch is not installed on my computer anymore):

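```lua
function EPECriterion:updateOutput(input, target)
   assert(input:nDimension() == 4 or input:nDimension() == 3)
   local diffMap = input - target
   local valid_pixels, epeMap
   if input:nDimension() == 4 then
      -- NaN is the only value with NaN ~= NaN, so torch.eq(target, target) is 0 exactly on NaN entries.
      -- Checking the first flow channel gives a BxHxW mask (assumes u and v are invalid together).
      valid_pixels = torch.eq(target, target):select(2, 1)
      epeMap = diffMap:norm(2, 2):select(2, 1)  -- BxHxW per-pixel EPE
   else
      valid_pixels = torch.eq(target, target):select(1, 1)  -- HxW mask
      epeMap = diffMap:norm(2, 1):select(1, 1)  -- HxW per-pixel EPE
   end
   -- Keep the mask for updateGradInput, and keep the valid EPEs as a flat 1D tensor:
   -- viewing it as Bx(...) would fail when samples have different numbers of invalid pixels.
   self.valid_pixels = valid_pixels
   self.EPE = epeMap[valid_pixels]
   -- Zero target of the same type and size as self.EPE (works for CPU and CUDA tensors)
   self.zeroEPE = self.EPE.new():resizeAs(self.EPE):zero()
   self.output = self.criterion:forward(self.EPE, self.zeroEPE)
   return self.output
end
```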
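To sanity-check the criterion on a tiny input, here is a plain-Lua reference (no Torch needed) of the quantity it should compute when the loss is the plain mean: the mean end-point error over pixels whose target is not NaN. The table layout (one `{u, v}` pair per pixel) and the name `masked_mean_epe` are hypothetical; this version treats a pixel as invalid if either component is NaN, slightly stricter than checking only the first channel.

```lua
-- Hypothetical plain-Lua reference for the masked forward pass.
-- pred and target are arrays of {u, v} pairs, one per pixel.
-- Returns the mean EPE over valid pixels and the number of valid pixels.
local function masked_mean_epe(pred, target)
  local sum, count = 0, 0
  for i = 1, #target do
    local tu, tv = target[i][1], target[i][2]
    -- NaN is the only value not equal to itself (same trick as torch.eq(target, target))
    if tu == tu and tv == tv then
      local du, dv = pred[i][1] - tu, pred[i][2] - tv
      sum = sum + math.sqrt(du * du + dv * dv)
      count = count + 1
    end
  end
  if count == 0 then return 0, 0 end  -- no valid pixel: avoid dividing by zero
  return sum / count, count
end

local nan = 0 / 0
local pred   = { {3, 4}, {1, 1}, {0, 0} }
local target = { {0, 0}, {nan, nan}, {0, 1} }
print(masked_mean_epe(pred, target))
```

Here the second pixel is skipped, so the result is the mean of the EPEs 5 and 1 over 2 valid pixels.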

You then do more or less the same thing in the updateGradInput function: the gradient has to use the same mask, so that invalid pixels receive zero gradient.
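The backward pass can be sketched in plain Lua as well, assuming the loss is the plain mean EPE over valid pixels (as an L1-style criterion against zeros would give; with a squared criterion the gradient would differ). For a valid pixel with difference d = pred - target and N valid pixels, the gradient is d / (N * ||d||); invalid pixels get exactly zero. The function name `masked_epe_grad` and the `{u, v}` table layout are hypothetical; in Torch the same idea would be applied to `diffMap` using the stored mask.

```lua
-- Hypothetical plain-Lua sketch of the masked backward pass.
-- Assumes loss = mean EPE over valid pixels (not the mean squared EPE).
-- pred and target are arrays of {u, v} pairs; returns {du, dv} gradients w.r.t. pred.
local function masked_epe_grad(pred, target)
  local valid, count = {}, 0
  for i = 1, #target do
    local tu, tv = target[i][1], target[i][2]
    valid[i] = (tu == tu and tv == tv)  -- false on NaN entries
    if valid[i] then count = count + 1 end
  end
  local grad = {}
  for i = 1, #target do
    grad[i] = {0, 0}  -- invalid pixels get exactly zero gradient
    if valid[i] then
      local du, dv = pred[i][1] - target[i][1], pred[i][2] - target[i][2]
      local epe = math.sqrt(du * du + dv * dv)
      -- d||d||/dd = d / ||d||; leave zero where epe == 0 (norm not differentiable there)
      if epe > 0 then
        grad[i] = {du / (epe * count), dv / (epe * count)}
      end
    end
  end
  return grad
end

local nan = 0 / 0
local g = masked_epe_grad({ {3, 4}, {1, 1}, {0, 0} }, { {0, 0}, {nan, nan}, {0, 1} })
```

With two valid pixels, the first gets (3, 4) / (2 * 5) and the NaN pixel gets no gradient at all.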

By the way, there is a PyTorch version that is more up to date, if you are willing to move from Torch to PyTorch (and I advise you to do so if not too much of your code is already in Torch). Invalid flow values are already taken care of there (for the case where they are encoded as 0 flow, but this is easily replaced with NaN flow): https://github.com/ClementPinard/FlowNetPytorch and https://github.com/ClementPinard/FlowNetPytorch/blob/master/multiscaleloss.py#L7
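The two encodings of missing flow mentioned above differ only in the validity test. A hypothetical plain-Lua sketch of both predicates, plus a conversion from the 0-flow encoding to the NaN encoding used by the Torch criterion above (the names `is_valid` and `zero_to_nan` are illustrative, not from either repository):

```lua
-- Hypothetical validity predicates for the two encodings of missing flow.
-- "nan":  a pixel is invalid when its target flow is NaN.
-- "zero": a pixel is invalid when both target components are exactly 0.
local function is_valid(u, v, convention)
  if convention == "zero" then
    return not (u == 0 and v == 0)
  end
  return u == u and v == v  -- default "nan" convention
end

-- Convert a 0-encoded sparse flow field (array of {u, v}) to the NaN encoding, in place.
local function zero_to_nan(flow)
  local nan = 0 / 0
  for i = 1, #flow do
    if not is_valid(flow[i][1], flow[i][2], "zero") then
      flow[i][1], flow[i][2] = nan, nan
    end
  end
  return flow
end

local flow = zero_to_nan({ {0, 0}, {0, 2}, {1.5, -1} })
```

One trade-off of the 0-flow encoding is that a genuinely static pixel with true flow (0, 0) is indistinguishable from a missing one, whereas NaN cannot collide with a real flow value.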