Hi Simon,

I'm trying to reproduce your recent paper on splatting-based synthesis for video frame interpolation; it's really nice work that inspired me a lot. However, I'm stuck on implementing the numerically stable softsplat you mention in Section 3, where you write that "warp Z0 to time t as Zmax ... this step is not and need not be differentiable ...". I'd appreciate it if you could further clarify the following two questions:
1. How do I implement the necessary `backward` function of the `torch.autograd.Function` that calculates Zmax during training? I've implemented the following snippet to calculate Zmax and it works well:
```python
class softsplat_zmax_func(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, tenIn, tenFlow):
        tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) # max weight
```
along with some modifications to the `softsplat` function:
```python
...
elif strMode.split('-')[0] == 'soft':
    tenMetricMax = softsplat_zmax_func.apply(tenMetric, tenFlow)
    tenMetric = torch.exp(tenMetric - tenMetricMax)
    # tenMetric = torch.exp(tenMetric)
    tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
...
```
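For context, my motivation for subtracting Zmax before the exponential is the usual float32 overflow issue; this is just a toy check of my own, not something from your code:

```python
import torch

tenMetric = torch.full([1], 100.0)            # a large importance metric in float32
print(torch.exp(tenMetric))                   # inf, which then poisons the normalization
print(torch.exp(tenMetric - tenMetric.max())) # 1.0, stays finite after subtracting the maximum
```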
These modifications work fine for inference, but I can't figure out how to design the `backward` function for `softsplat_zmax_func`; autograd still expects some gradient from it, and I don't want to mess up the training.
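My current guess, since the paper says this step need not be differentiable, is to treat Zmax as a constant and return no gradients at all, roughly as sketched below (the cupy kernel launch is elided, so this is only a sketch); is that what you intended?

```python
class softsplat_zmax_func(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, tenIn, tenFlow):
        tenOut = tenIn.new_zeros(tenIn.shape) # max weight
        # ... launch the cupy kernel that forward-warps tenIn into tenOut using an atomic max ...
        return tenOut

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(self, tenOutgrad):
        # Zmax only serves numerical stability, so no gradient is propagated back to tenIn or tenFlow
        return None, None
```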
2. I notice that the atomic max in CUDA/CuPy does not support float operands, yet you said that "This can be efficiently computed in parallel using an atomic max". Could you please share with us how you handled this?
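To make the second question concrete, the workaround I've been experimenting with emulates a float32 atomic max through an `atomicCAS` loop inside a CuPy raw kernel; `atomicMaxFloat` and `kernel_zmax` below are just my own placeholder names, and I'm not sure this is what you had in mind:

```python
import cupy

strKernel = r'''
    // CUDA's atomicMax() only supports integer types, so emulate a float32 atomic max
    // with a compare-and-swap loop on the reinterpreted bits
    __device__ float atomicMaxFloat(float* fltAddress, float fltValue) {
        int* intAddress = (int*) fltAddress;
        int intOld = *intAddress;
        int intAssumed;

        do {
            intAssumed = intOld;
            intOld = atomicCAS(intAddress, intAssumed,
                               __float_as_int(fmaxf(fltValue, __int_as_float(intAssumed))));
        } while (intAssumed != intOld);

        return __int_as_float(intOld);
    }

    extern "C" __global__ void kernel_zmax(const int intN, const float* fltIn, float* fltOut) {
        int intIndex = blockIdx.x * blockDim.x + threadIdx.x;

        if (intIndex < intN) {
            atomicMaxFloat(&fltOut[0], fltIn[intIndex]); // every thread folds its value into one maximum
        }
    }
'''

kernelZmax = cupy.RawKernel(strKernel, 'kernel_zmax')

fltIn = cupy.random.rand(1024, dtype=cupy.float32)
fltOut = cupy.full([1], float('-inf'), dtype=cupy.float32)
kernelZmax((8,), (128,), (cupy.int32(1024), fltIn, fltOut))

print(float(fltOut[0]), float(fltIn.max())) # these two should agree
```

An alternative I have seen is reinterpreting the float bits as integers and using the integer `atomicMax` directly, which only preserves the ordering when all values are non-negative.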
Thanks in advance!