sniklaus / softmax-splatting

an implementation of softmax splatting for differentiable forward warping using PyTorch

Some questions about model and fine-tuning #5

Closed linchuming closed 4 years ago

linchuming commented 4 years ago

Nice work! However, a few questions remain after reading the paper:

  1. GridNet has only one input, but the feature pyramid extractor outputs three features with different spatial sizes. How are these outputs fed into the GridNet?
  2. How is PWC-Net fine-tuned with your model? Do I need to add an additional loss?
  3. The feature pyramid generates features at three scales. To splat these three features to the target frame, we need three importance metrics, one per feature. According to your paper, we can produce the importance metric at the original scale. So is the lower-scale metric obtained by downsampling the original-scale metric?
sniklaus commented 4 years ago

Thank you for your interest in our paper!

  1. We have multiple inputs to the GridNet, one per level, as shown in the figure below. For your convenience, I also added the number of output features for each block.

[Figure: GridNet with one input per pyramid level, annotated with the number of output features for each block]

  2. We initially take an off-the-shelf PWC-Net and only optimize the parameters of our frame synthesis network and our feature pyramid extractor (and, to be precise, the scale parameters for each softmax splatting operator). When fine-tuning PWC-Net, we do not change our loss; all we do is additionally backpropagate through PWC-Net and update its weights as well.
  3. We compute the importance metric only at the highest level and then downsample it to obtain the importance metric for the lower levels (see the sketch after this list). My apologies that the paper did not make this clear. You could also compute the importance metric separately for each level; it has been a while since I tried that configuration, but it performed similarly well if I remember correctly (I did not try that configuration with fine-tuning the importance metric, though).
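To make the third point concrete, here is a rough sketch of the downsampling. This is not our exact training code: `backwarp` and `softsplat` are stand-ins for the backward warping helper and the softmax splatting operator from this repository, and the brightness-constancy metric follows the paper.

    import torch
    import torch.nn.functional as F

    def splat_pyramid(tenOne, tenTwo, tenFlow, tenPyramid, fltAlpha=-20.0):
        # importance metric at the finest level, based on brightness constancy
        tenMetric = fltAlpha * torch.abs(tenOne - backwarp(tenTwo, tenFlow)).mean(1, keepdim=True)

        tenWarped = []
        for intLevel, tenFeat in enumerate(tenPyramid):  # fine to coarse
            fltScale = 1.0 / (2 ** intLevel)
            # the flow has to be rescaled when it is resized to a coarser level
            tenFlowLevel = fltScale * F.interpolate(tenFlow, scale_factor=fltScale, mode='bilinear', align_corners=False)
            # downsample the full-resolution metric instead of recomputing it (kernel 1 at the finest level is a no-op)
            tenMetricLevel = F.avg_pool2d(tenMetric, 2 ** intLevel)
            tenWarped.append(softsplat(tenFeat, tenFlowLevel, tenMetricLevel, 'softmax'))

        return tenWarped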
linchuming commented 4 years ago

@sniklaus Thanks for your detailed response! I tried to train the interpolation model with PWC-Net without adding an additional loss, but PWC-Net is unable to converge. Should I adjust the learning rate for PWC-Net, or first train the interpolation model for several epochs and only then update the PWC-Net weights?

Another question: softmax splatting includes a small U-Net with three levels to refine the importance map. Does each softmax splatting module have an independent U-Net?

sniklaus commented 4 years ago

Would you mind elaborating on your current state of progress? Have you implemented and trained the frame synthesis network as well as the feature pyramid extractor, and do you now want to fine-tune PWC-Net after having successfully trained the rest of the pipeline?

I just checked the importance metric calculation again and noticed that we only compute it for the highest level and then downsample it to obtain the importance metric for the lower levels. I updated my response above accordingly. My apologies for mixing it up; it has been almost two years since I wrote the code for this project.

Regarding the U-Net to fine-tune the importance metric, please see my response in #4, which you may find useful. There, I recommend first trying it without this component, since the effect of fine-tuning the importance metric is minor anyway.
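In the meantime, to illustrate the two stages I mentioned in my first reply (freezing PWC-Net at first, then additionally backpropagating through it), here is a rough sketch. The module names netFlow, netExtract, and netSynthesis are placeholders, not the actual code of this project.

    import torch

    netFlow = PwcNet()             # off-the-shelf, pre-trained optical flow (frozen at first)
    netExtract = FeaturePyramid()  # feature pyramid extractor
    netSynthesis = GridNet()       # frame synthesis network

    # stage one: only optimize the synthesis network and the feature extractor
    # (plus the per-operator softmax splatting scale parameters, omitted here)
    for tenParam in netFlow.parameters():
        tenParam.requires_grad = False

    objOptimizer = torch.optim.Adam(
        list(netExtract.parameters()) + list(netSynthesis.parameters()), lr=1e-4
    )

    # stage two: same loss, but now also backpropagate through PWC-Net and update its weights
    for tenParam in netFlow.parameters():
        tenParam.requires_grad = True

    objOptimizer = torch.optim.Adam(
        list(netExtract.parameters()) + list(netSynthesis.parameters()) + list(netFlow.parameters()), lr=1e-4
    )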

linchuming commented 4 years ago

@sniklaus Thanks for your response again! I tried to reimplement your softmax splatting approach. I use the following configuration:

  1. feature pyramid extractor and GridNet as in your paper.
  2. compute the importance metric from the images and downsample it for the lower levels.
  3. use the Laplacian loss:

    
    # GaussianConv (a fixed Gaussian blur) is defined elsewhere, see the sketch after this list
    import torch.nn as nn
    import torch.nn.functional as F

    class LaplacianPyramid(nn.Module):
        def __init__(self, max_level=5):
            super(LaplacianPyramid, self).__init__()
            self.gaussian_conv = GaussianConv()
            self.max_level = max_level

        def forward(self, X):
            t_pyr = []
            current = X
            for level in range(self.max_level):
                t_gauss = self.gaussian_conv(current)
                t_diff = current - t_gauss
                t_pyr.append(t_diff)
                current = F.avg_pool2d(t_gauss, 2)
            t_pyr.append(current)
            return t_pyr

    class LaplacianLoss(nn.Module):
        def __init__(self):
            super(LaplacianLoss, self).__init__()
            self.criterion = nn.L1Loss()
            self.lap = LaplacianPyramid()

        def forward(self, x, y):
            x_lap, y_lap = self.lap(x), self.lap(y)
            return sum(self.criterion(a, b) for a, b in zip(x_lap, y_lap))

  4. use the pre-trained PWC-Net you published at https://github.com/sniklaus/pytorch-pwc.
  5. do not fine-tune PWC-Net.
  6. patch size of 256x256, with random horizontal and vertical flips and random temporal order.
  7. train the model for 50 epochs with a learning rate of 1e-4.
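For reference, a minimal GaussianConv that the LaplacianPyramid above could use, a fixed Gaussian blur, assumed here with a 5x5 binomial kernel:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class GaussianConv(nn.Module):
        def __init__(self):
            super(GaussianConv, self).__init__()
            # fixed 5x5 binomial approximation of a Gaussian kernel
            kernel = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
            kernel = kernel[:, None] * kernel[None, :]
            kernel = kernel / kernel.sum()
            self.register_buffer('kernel', kernel[None, None])

        def forward(self, x):
            # blur each channel independently with the same kernel
            channels = x.shape[1]
            weight = self.kernel.expand(channels, 1, 5, 5)
            return F.conv2d(x, weight, padding=2, groups=channels)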

Finally, I get a PSNR of 34.96 on the Vimeo90k test set, whereas according to your paper I should get a PSNR of 35.54.
Could you tell me what I should change in my configuration? Thanks again!
sniklaus commented 4 years ago

Looks like you got pretty good results already, congrats!

  1. You could try weighting the levels differently, for example giving the coarser levels in the pyramid more weight (see the sketch after this list). Your mileage may vary, but there is definitely room for improvement.
  2. Notice that we used a PWC-Net pre-trained on FlyingChairs, which is not an official model, so you may get worse or better results than we did.
  3. You might want to try a few different training schedules to see what works best for you. Lowering the learning rate towards the end of your training may be worth a shot, too.
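As an example of the first point, the forward pass of your LaplacianLoss could weight the levels like this; the 2 ** level weights are just one possibility, not necessarily what we used:

    def forward(self, x, y):
        x_lap, y_lap = self.lap(x), self.lap(y)
        # weight the coarser levels more, for example by a factor of 2 per level
        return sum((2 ** level) * self.criterion(a, b) for level, (a, b) in enumerate(zip(x_lap, y_lap)))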

Good luck!