apple / axlearn

An Extensible Deep Learning Library
Apache License 2.0

Implement ConvXDTranspose #853

Closed ds-hwang closed 16 hours ago

ds-hwang commented 4 days ago
This PR implements a unified transpose convolution covering 1D/2D/3D inputs, SAME/VALID/CAUSAL and
arbitrary explicit padding, and arbitrary window, stride, and dilation.

SAME and VALID are equivalent to jax.lax.conv_transpose(). CAUSAL is defined in this PR.

Each padding literal expands to an explicit (left, right) padding according to the formulas below:
* SAME: padding=(min(window-1, ceil((window+stride-2)/2)), max(stride-1, floor((window+stride-2)/2)))
     pad_total = window+stride-2
     when stride > window -> (window-1, stride-1)
* VALID: padding=(window-1, max(stride-1, window-1)) 
     pad_total = window+stride-2 + max(window-stride, 0)
     when stride > window -> (window-1, stride-1)
* CAUSAL: padding=(window-1, stride-1)
     pad_total = window+stride-2
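
The three formulas above can be sketched in a few lines of Python. This is a minimal illustration for a single spatial axis, not the code from this PR: the helper name `explicit_padding` and its signature are assumptions made for this example.

```python
from typing import Tuple


def explicit_padding(mode: str, window: int, stride: int) -> Tuple[int, int]:
    """Hypothetical helper: (left, right) padding for one spatial axis."""
    total = window + stride - 2
    if mode == "SAME":
        left = min(window - 1, -(-total // 2))  # ceil(total / 2) without floats
        right = max(stride - 1, total // 2)     # floor(total / 2)
    elif mode == "VALID":
        left = window - 1
        right = max(stride - 1, window - 1)
    elif mode == "CAUSAL":
        left = window - 1
        right = stride - 1
    else:
        raise ValueError(f"Unknown padding mode: {mode}")
    return left, right


# Print the paddings for window=3 across the strides used in the illustrations.
for stride in (1, 2, 3, 4):
    row = {m: explicit_padding(m, 3, stride) for m in ("SAME", "VALID", "CAUSAL")}
    print(f"stride={stride}: {row}")
```

To handle dilation, pass the dilated window as `window`, e.g. `explicit_padding("SAME", 5, 2)` for window=3, dilation=2.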

Note: output_size = input_size*stride - (window+stride-2) + pad_total
                  = input_size*stride  <- "SAME", "CAUSAL"
                  = input_size*stride + max(window-stride, 0)  <- "VALID"

Note: In the above equations, `window` is replaced with `dilate_window` when dilation > 1, where
    dilate_window = (window - 1) * dilation + 1. See conv_dilate_window().
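
As a quick check of the two notes above, the closed-form output size can be written as a small function. This is a sketch under assumptions: `dilate_window` here is a stand-in written for this example (the real conv_dilate_window() signature is not shown in this description), and `transpose_output_size` is a hypothetical name.

```python
def dilate_window(window: int, dilation: int) -> int:
    """Effective kernel extent once taps are spaced `dilation` apart."""
    return (window - 1) * dilation + 1


def transpose_output_size(input_size: int, window: int, stride: int,
                          mode: str, dilation: int = 1) -> int:
    """Closed-form output size for one axis, following the note above."""
    w = dilate_window(window, dilation)
    if mode in ("SAME", "CAUSAL"):
        return input_size * stride
    if mode == "VALID":
        return input_size * stride + max(w - stride, 0)
    raise ValueError(f"Unknown padding mode: {mode}")


print(dilate_window(3, 2))
print(transpose_output_size(4, window=3, stride=2, mode="VALID"))
print(transpose_output_size(4, window=3, stride=2, mode="VALID", dilation=2))
```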

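The following illustrations demonstrate how Conv Transpose operates on the input [0, 0, 1, 1],
assuming all kernel values are set to 1 so that the output values are easy to follow. In the
`paddings:` rows, `*` marks a zero inserted between input elements by the stride; in the dilated
window rows, `*` marks a position the kernel skips.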

In the window=3 and stride=1 case, this function creates outputs as follows:
* "SAME" padding=(1, 1)
                pad|       |pad
    paddings:     0|0 0 1 1|0
                  0 0 0  -> 0
                    0 0 1  -> 1
                      0 1 1  -> 2
                        1 1 0  -> 2

* "VALID" padding=(2, 2)
                pad  |       |pad
    paddings:     0 0|0 0 1 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 1  -> 1
                        0 1 1  -> 2
                          1 1 0  -> 2
                            1 0 0  -> 1

* "CAUSAL" padding=(2, 0)
                pad  |       |pad
    paddings:     0 0|0 0 1 1|
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 1  -> 1
                        0 1 1  -> 2

In the window=3 and stride=2 case, this function creates outputs as follows:
* "SAME" padding=(2, 1)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1

* "VALID" padding=(2, 2)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1
                                  1 0 0  -> 1

* "CAUSAL" padding=(2, 1)
                pad  |             |pad
    paddings:     0 0|0 * 0 * 1 * 1|0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 1  -> 1
                            0 1 0  -> 1
                              1 0 1  -> 2
                                0 1 0  -> 1

In the window=3 and stride=3 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 2)
                pad  |                   |pad
    paddings:     0 0|0 * * 0 * * 1 * * 1|0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 0  -> 0
                            0 0 0  -> 0
                              0 0 1  -> 1
                                0 1 0  -> 1
                                  1 0 0  -> 1
                                    0 0 1  -> 1
                                      0 1 0  -> 1
                                        1 0 0  -> 1

In the window=3 and stride=4 case, this function creates outputs as follows:
* "SAME", "VALID" and "CAUSAL" padding=(2, 3)
                pad  |                         |pad
    paddings:     0 0|0 * * * 0 * * * 1 * * * 1|0 0 0
                  0 0 0  -> 0
                    0 0 0  -> 0
                      0 0 0  -> 0
                        0 0 0  -> 0
                          0 0 0  -> 0
                            0 0 0  -> 0
                              0 0 0  -> 0
                                0 0 0  -> 0
                                  0 0 1  -> 1
                                    0 1 0  -> 1
                                      1 0 0  -> 1
                                        0 0 0  -> 0
                                          0 0 1  -> 1
                                            0 1 0  -> 1
                                              1 0 0  -> 1
                                                0 0 0  -> 0
    Here is how to compute output_size for the above example:
      1.          |_|  -(window-1)
      2.              |_______________________|  (input_size-1)*stride + 1
      3.          |_|                           |___|  + pad_total

    So, output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total
                    = input_size*stride - (window+stride-2) + pad_total
                    = input_size*stride  <- "SAME", "CAUSAL"
                    = input_size*stride + max(window-stride, 0)  <- "VALID"
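
The three counting steps above map directly onto a small simulation: insert zeros between inputs, pad, then count window positions. The sketch below assumes an all-ones kernel and a single axis; `insert_zeros` and `conv_transpose_1d_ones` are hypothetical names used only for this example.

```python
def insert_zeros(x, stride):
    """Places stride-1 zeros between neighbours (the `*` slots in the diagrams)."""
    out = []
    for i, v in enumerate(x):
        if i:
            out.extend([0] * (stride - 1))
        out.append(v)
    return out


def conv_transpose_1d_ones(x, window, stride, padding):
    left, right = padding
    stretched = insert_zeros(x, stride)            # step 2: (input_size-1)*stride + 1
    padded = [0] * left + stretched + [0] * right  # step 3: + pad_total
    n_out = len(padded) - (window - 1)             # step 1: -(window-1)
    return [sum(padded[i:i + window]) for i in range(n_out)]


# The window=3, stride=4 example: padding=(2, 3) for all three modes.
y = conv_transpose_1d_ones([0, 0, 1, 1], window=3, stride=4, padding=(2, 3))
print(len(y), y)
```

For this example the output length equals input_size*stride = 16, matching the closed form for all three modes because stride > window.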

OTOH, when dilation > 1, dilate_window = (window - 1) * dilation + 1. For example,
when window=3 and dilation=2, dilate_window=5.

In the window=3, dilation=2 and stride=2 case, this function creates outputs as follows:
* "SAME" padding=(3, 2)
                pad    |             |pad
    paddings:     0 0 0|0 * 0 * 1 * 1|0 0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 1  -> 1
                          0 * 0 * 0  -> 0
                            0 * 1 * 1  -> 2
                              0 * 0 * 0  -> 0
                                1 * 1 * 0  -> 2

* "VALID" padding=(4, 4)
                pad      |             |pad
    paddings:     0 0 0 0|0 * 0 * 1 * 1|0 0 0 0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 0  -> 0
                          0 * 0 * 1  -> 1
                            0 * 0 * 0  -> 0
                              0 * 1 * 1  -> 2
                                0 * 0 * 0  -> 0
                                  1 * 1 * 0  -> 2
                                    0 * 0 * 0  -> 0
                                      1 * 0 * 0  -> 1

* "CAUSAL" padding=(4, 1)
                pad      |             |pad
    paddings:     0 0 0 0|0 * 0 * 1 * 1|0
                  0 * 0 * 0  -> 0
                    0 * 0 * 0  -> 0
                      0 * 0 * 0  -> 0
                        0 * 0 * 0  -> 0
                          0 * 0 * 1  -> 1
                            0 * 0 * 0  -> 0
                              0 * 1 * 1  -> 2
                                0 * 0 * 0  -> 0
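
The dilated diagrams above can be checked the same way by letting the kernel taps skip `dilation - 1` positions. This is a minimal sketch with an all-ones kernel on one axis; `conv_transpose_1d_dilated_ones` is a hypothetical name, not part of this PR.

```python
def conv_transpose_1d_dilated_ones(x, window, stride, dilation, padding):
    left, right = padding
    # Insert stride-1 zeros between neighbouring inputs.
    stretched = []
    for i, v in enumerate(x):
        if i:
            stretched.extend([0] * (stride - 1))
        stretched.append(v)
    padded = [0] * left + stretched + [0] * right
    span = (window - 1) * dilation + 1  # dilate_window
    # Each output sums `window` taps spaced `dilation` apart.
    return [
        sum(padded[start + k * dilation] for k in range(window))
        for start in range(len(padded) - span + 1)
    ]


x = [0, 0, 1, 1]
for mode, pad in {"SAME": (3, 2), "VALID": (4, 4), "CAUSAL": (4, 1)}.items():
    print(mode, conv_transpose_1d_dilated_ones(x, 3, 2, 2, pad))
```

In JAX, zero insertion and tap spacing correspond to the `lhs_dilation` and `rhs_dilation` arguments of `jax.lax.conv_general_dilated`; whether this PR is built on that call is not stated in this description.
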
ds-hwang commented 2 days ago

@ruomingp could you approve it? from 926

ds-hwang commented 16 hours ago

Thank you for the review!