Closed mathmanu closed 6 years ago
In other words, what I am suggesting is to modify the forward() function of FlowNetS as follows:
def forward(self, x):
    _, _, h, w = x.size()
    # ...
    # existing code here
    # ...
    flow0 = nn.functional.upsample(flow2, size=(h, w), mode='bilinear')
    if self.training:
        return [flow0, flow2, flow3, flow4, flow5]
    else:
        return flow0
Hmm, that could be a good thing for sparse targets, but I don't like the idea of treating FlowNetSUp as a different network. You should be able to use a pretrained FlowNetS as a FlowNetSUp whenever you want, without having to convert it.
I get the problem and its solution, but it should be in the learning process rather than in the network definition that we try to learn full res flow. I'll see what can be done nicely :)
Sure. It should be possible to put the upsampling outside the network definition - just before calling the loss function.
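A minimal sketch of what "upsampling outside the network" could look like, assuming the loss is a mean end-point error and that the predicted flow values are not rescaled when resized. The helper names `upsample_to_target`, `epe` and `full_res_loss` are hypothetical and are not taken from the repository.

```python
import torch
import torch.nn.functional as F


def upsample_to_target(flow, target):
    """Bilinearly resize a predicted flow map to the target's spatial size."""
    h, w = target.shape[-2:]
    return F.interpolate(flow, size=(h, w), mode='bilinear', align_corners=False)


def epe(pred, target):
    """Mean end-point error: L2 norm over the flow channels, averaged over pixels."""
    return torch.norm(pred - target, p=2, dim=1).mean()


def full_res_loss(flow2, target):
    """Hypothetical helper: compare the upsampled prediction to the untouched target."""
    return epe(upsample_to_target(flow2, target), target)


# Example: a quarter-resolution prediction compared against a full-resolution target.
flow2 = torch.zeros(2, 2, 16, 24)
target = torch.ones(2, 2, 64, 96)
loss = full_res_loss(flow2, target)
```

With this arrangement the network definition stays unchanged, and the target is never resized for the finest scale.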
Fixed
The current multi-scale loss never computes the loss at the native resolution. This is because the highest-resolution output of FlowNetS (flow2) is smaller than the target or input resolution.
The problem is more severe in the case of sparse targets (e.g. KITTI), since we don't really use an accurate resampling of the sparse target, but use max pooling instead. I believe this may be because PyTorch may not have nearest-neighbor resizing with support for flexible output sizes. Even if we had nearest-neighbor resizing, that would be inaccurate too.
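To make the inaccuracy concrete, here is a sketch of one way max-pool based downsampling of a sparse target can work. It assumes invalid pixels are encoded as exactly 0 and pools positive and negative values separately. The function name `sparse_max_pool` and the toy tensor are illustrative only.

```python
import torch
import torch.nn.functional as F


def sparse_max_pool(flow, size):
    """Downsample a sparse flow map to `size` with adaptive max pooling.

    Assumes invalid pixels are exactly 0. Per output cell, the largest positive
    value and the most negative value are pooled separately and then summed.
    With at most one valid pixel per cell, the result is that pixel's value.
    """
    positive = (flow > 0).float()
    negative = (flow < 0).float()
    pooled_pos = F.adaptive_max_pool2d(flow * positive, size)
    pooled_neg = F.adaptive_max_pool2d(-flow * negative, size)
    return pooled_pos - pooled_neg


# Toy 4x4 sparse map with two valid pixels, pooled down to 2x2.
sparse = torch.zeros(1, 1, 4, 4)
sparse[0, 0, 0, 1] = 3.0
sparse[0, 0, 2, 2] = -2.0
coarse = sparse_max_pool(sparse, (2, 2))
```

Each valid value survives, but it now stands for a whole 2x2 cell regardless of where in that cell it was measured, which is the positional error described above.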
A quick fix for this would be to do a bilinear upsampling at the output for FlowNetS. Then the error on the first resolution would be computed without any resizing of target.
The following derivative of FlowNetS, which I call FlowNetSUp, does the same thing. There is really no need to define a new class; this can be incorporated into FlowNetS itself.
If the above line of reasoning is correct then this change should provide improved training and hence better accuracy.
file FlowNetSUp.py
import torch
import torch.nn as nn
from torch.nn.init import kaiming_normal
import math
from .FlowNetS import FlowNetS

__all__ = [
    'FlowNetSUp', 'flownets_up', 'flownets_up_bn'
]
class FlowNetSUp(FlowNetS):
def flownets_up(path=None):
    model = FlowNetSUp(batchNorm=False)
    if path is not None:
        data = torch.load(path)
        if 'state_dict' in data.keys():
            model.load_state_dict(data['state_dict'])
        else:
            model.load_state_dict(data)
    return model

def flownets_up_bn(path=None):
    model = FlowNetSUp(batchNorm=True)
    if path is not None:
        data = torch.load(path)
        if 'state_dict' in data.keys():
            model.load_state_dict(data['state_dict'])
        else:
            model.load_state_dict(data)
    return model
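The loaders above only work with FlowNetS checkpoints if the subclass adds no parameters of its own. The toy sketch below illustrates that point under the assumption that the subclass overrides only forward(). `TinyFlowNet` and `TinyFlowNetUp` are hypothetical stand-ins, not the real FlowNetS classes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyFlowNet(nn.Module):
    """Toy stand-in for FlowNetS: predicts a 2-channel flow at 1/4 resolution."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(6, 2, kernel_size=3, stride=4, padding=1)

    def forward(self, x):
        return self.conv(x)


class TinyFlowNetUp(TinyFlowNet):
    """Same parameters as the parent; only forward() adds bilinear upsampling."""

    def forward(self, x):
        h, w = x.shape[-2:]
        flow = super().forward(x)
        return F.interpolate(flow, size=(h, w), mode='bilinear', align_corners=False)


base = TinyFlowNet()
up = TinyFlowNetUp()
# The state_dict keys are identical, so the weights load without any conversion.
up.load_state_dict(base.state_dict())

x = torch.randn(1, 6, 32, 48)
low_res = base(x)   # quarter-resolution flow
full_res = up(x)    # same weights, output resized to the input size
```

Because the upsampling layer has no weights, the same checkpoint serves both variants, which matches the point made in the thread about not needing a separate network.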