sniklaus / softmax-splatting

an implementation of softmax splatting for differentiable forward warping using PyTorch
466 stars 58 forks source link

Question about technical implementation details of Z^max (Equation 3) #60

Open Justin62628 opened 1 year ago

Justin62628 commented 1 year ago

Hi Simon,

I'm trying to reproduce your recent paper on splatting-based synthesis for video frame interpolation; it is really nice work and has inspired me a lot. However, I'm stuck implementing the numerically stable softsplat you mention in Section 3, where you say that "warp Z0 to time t as Zmax ... this step is and need not be differentiable ...". I would appreciate it if you could clarify the following two questions:

  1. How to implement the necessary `backward` function of `torch.autograd.Function` when computing Zmax during training. I've implemented the following snippet to calculate Zmax, and it works well:

class softsplat_zmax_func(torch.autograd.Function):
    """Forward-splat a metric with an atomic max to obtain Zmax.

    Every element ``tenIn[n, c, y, x]`` is scattered to the four integer
    pixels around ``(x + tenFlow[n, 0, y, x], y + tenFlow[n, 1, y, x])``.
    Each destination pixel keeps the largest value it receives. The output
    buffer starts at zero, so the stored value is ``max(0, splatted values)``
    and pixels that receive nothing stay 0.

    Zmax only serves as a stabilizing shift and is not differentiated. The
    output is marked non-differentiable and ``backward`` returns no
    gradients.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, tenIn, tenFlow):
        """Compute the max-splatted metric.

        Args:
            tenIn: 4-D metric tensor indexed as (N, C, H, W). The kernel
                comment says C == 1 is expected -- TODO confirm.
            tenFlow: 4-D flow tensor indexed as (N, 2, H, W). Channel 0 is
                the x displacement and channel 1 is the y displacement.

        Returns:
            Tensor with the shape of ``tenIn`` holding the per-pixel maximum.

        Raises:
            NotImplementedError: if ``tenIn`` is not a CUDA tensor. There is
                no CPU implementation.
        """
        # Raise rather than `assert`. An assert is stripped under `python -O`,
        # and this function would then silently return all zeros.
        if not tenIn.is_cuda:
            raise NotImplementedError('softsplat_zmax_func is only implemented for CUDA tensors')
        # end

        tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]])  # max weight

        cuda_launch(cuda_kernel('zmax_out', '''
            __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
                // CUDA has no float atomicMax, so compare raw bit patterns instead.
                // Non-negative floats order like their signed-int bits. Negative
                // floats order in reverse of their unsigned bits, so an unsigned
                // atomicMin yields the float maximum for them.
                float old;
                old = !signbit(value) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) :
                    __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));

                return old;
            }

            extern "C" __global__ void __launch_bounds__(512) zmax_out(
                const int n,
                const {{type}}* __restrict__ tenIn,  // Z input only, B 1 H W
                const {{type}}* __restrict__ tenFlow,
                {{type}}* __restrict__ tenOut  // Z max output
            ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
                const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut);
                const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut)                  ) % SIZE_1(tenOut);
                const int intY = ( intIndex / SIZE_3(tenOut)                                   ) % SIZE_2(tenOut);
                const int intX = ( intIndex                                                    ) % SIZE_3(tenOut);

                assert(SIZE_1(tenFlow) == 2);

                {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX);
                {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX);

                // Skip this element rather than returning, so the remaining
                // iterations of this thread's grid-stride loop still run.
                if (isfinite(fltX) == false) { continue; }
                if (isfinite(fltY) == false) { continue; }

                {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX);

                const int intNorthwestX = (int) (floor(fltX));
                const int intNorthwestY = (int) (floor(fltY));

                // Scatter to the 2x2 neighbourhood: northwest, northeast, southwest, southeast.
                for (int intOffsetY = 0; intOffsetY < 2; intOffsetY += 1) {
                    for (int intOffsetX = 0; intOffsetX < 2; intOffsetX += 1) {
                        const int intTargetX = intNorthwestX + intOffsetX;
                        const int intTargetY = intNorthwestY + intOffsetY;

                        if ((intTargetX >= 0) && (intTargetX < SIZE_3(tenOut)) && (intTargetY >= 0) && (intTargetY < SIZE_2(tenOut))) {
                            atomicMaxFloat(&tenOut[OFFSET_4(tenOut, intN, intC, intTargetY, intTargetX)], fltIn);
                        }
                    }
                }
            } }
        ''', {
            'tenIn': tenIn,
            'tenFlow': tenFlow,
            'tenOut': tenOut
        }))(
            grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]),
            block=tuple([512, 1, 1]),
            args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()],
            stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream)
        )

        # Zmax is a constant shift for exp(Z - Zmax). Gradients therefore reach
        # the metric only through the un-shifted Z term, and nothing needs to be
        # saved for backward.
        self.mark_non_differentiable(tenOut)

        return tenOut
    # end

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(self, tenOutgrad):
        """Return no gradients for ``tenIn`` and ``tenFlow``, since Zmax is treated as a constant."""
        return None, None
    # end
# end
along with some modifications to the `softsplat` function:
```python
...
    elif strMode.split('-')[0] == 'soft':
        # Zmax is splatted to the *target* coordinates (x + flow, y + flow), but
        # tenMetric is indexed at *source* coordinates. The subtraction below
        # therefore pairs each source pixel with the max of whatever landed on
        # the same (y, x) in the target frame. It does not use the max at the
        # pixels this source pixel actually splats to.
        # NOTE(review): confirm this is the intended stabilizer for Equation 3.
        # A per-destination max may be what is actually needed.
        tenMetricMax = softsplat_zmax_func.apply(tenMetric, tenFlow)
        tenMetric = torch.exp(tenMetric - tenMetricMax)
        # Variant without the Zmax shift:
        # tenMetric = torch.exp(tenMetric)
        tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1)
...

```

It works fine for inference, but I can't figure out how to design the `backward` function for `softsplat_zmax_func`, since it seems to need some gradient so as not to break training.

  2. I noticed that CUDA's `atomicMax`, as used in CuPy kernels, does not support floating-point values, yet you wrote that "This can be efficiently computed in parallel using an atomic max". Could you please share how you handled this?

Thanks in advance!