Closed: ds-hwang closed this pull request 16 hours ago.
This PR implements a unified transposed convolution that covers 1D/2D/3D inputs; "SAME", "VALID", "CAUSAL", and arbitrary padding; and arbitrary window, stride, and dilation. "SAME" and "VALID" are equivalent to `jax.lax.conv_transpose()`, while "CAUSAL" is defined in this PR. Each literal padding follows the formulas below:

* "SAME": `padding=(min(window-1, ceil((window+stride-2)/2)), max(stride-1, floor((window+stride-2)/2)))`, so `pad_total = window+stride-2`. When `stride > window`, this becomes `(window-1, stride-1)`.
* "VALID": `padding=(window-1, max(stride-1, window-1))`, so `pad_total = window+stride-2 + max(window-stride, 0)`. When `stride > window`, this becomes `(window-1, stride-1)`.
* "CAUSAL": `padding=(window-1, stride-1)`, so `pad_total = window+stride-2`.

Note: `output_size = input_size*stride - (window+stride-2) + pad_total`, which equals `input_size*stride` for "SAME" and "CAUSAL", and `input_size*stride + max(window-stride, 0)` for "VALID".

Note: In the equations above, `window` can be replaced with `dilate_window` when `dilation > 1`, where `dilate_window = (window - 1) * dilation + 1`. See `conv_dilate_window()`.

The following illustrations show how the transposed convolution operates. For simplicity, all kernel values are set to 1, so each output is the sum of the values under the window.
In each `paddings` row, `|` separates the left padding, the input, and the right padding; from stride=2 on, `*` marks the `stride-1` zeros inserted between adjacent inputs. Each `a b c` -> `y` entry lists the values under one window position (inserted zeros are written as `0`) and the resulting output `y`; the window moves one step at a time.

In the window=3 and stride=1 case, this function creates outputs as follows:

* "SAME", padding=(1, 1), paddings `0|0 0 1 1|0`: `0 0 0` -> 0, `0 0 1` -> 1, `0 1 1` -> 2, `1 1 0` -> 2.
* "VALID", padding=(2, 2), paddings `0 0|0 0 1 1|0 0`: `0 0 0` -> 0, `0 0 0` -> 0, `0 0 1` -> 1, `0 1 1` -> 2, `1 1 0` -> 2, `1 0 0` -> 1.
* "CAUSAL", padding=(2, 0), paddings `0 0|0 0 1 1|`: `0 0 0` -> 0, `0 0 0` -> 0, `0 0 1` -> 1, `0 1 1` -> 2.

In the window=3 and stride=2 case, this function creates outputs as follows:

* "SAME", padding=(2, 1), paddings `0 0|0 * 0 * 1 * 1|0`: `0 0 0` -> 0 (four times), `0 0 1` -> 1, `0 1 0` -> 1, `1 0 1` -> 2, `0 1 0` -> 1.
* "VALID", padding=(2, 2), paddings `0 0|0 * 0 * 1 * 1|0 0`: `0 0 0` -> 0 (four times), `0 0 1` -> 1, `0 1 0` -> 1, `1 0 1` -> 2, `0 1 0` -> 1, `1 0 0` -> 1.
* "CAUSAL", padding=(2, 1), paddings `0 0|0 * 0 * 1 * 1|0`: `0 0 0` -> 0 (four times), `0 0 1` -> 1, `0 1 0` -> 1, `1 0 1` -> 2, `0 1 0` -> 1.

In the window=3 and stride=3 case, this function creates outputs as follows:

* "SAME", "VALID", and "CAUSAL", padding=(2, 2), paddings `0 0|0 * * 0 * * 1 * * 1|0 0`: `0 0 0` -> 0 (six times), `0 0 1` -> 1, `0 1 0` -> 1, `1 0 0` -> 1, `0 0 1` -> 1, `0 1 0` -> 1, `1 0 0` -> 1.

In the window=3 and stride=4 case, this function creates outputs as follows:

* "SAME", "VALID", and "CAUSAL", padding=(2, 3), paddings `0 0|0 * * * 0 * * * 1 * * * 1|0 0 0`: `0 0 0` -> 0 (eight times), `0 0 1` -> 1, `0 1 0` -> 1, `1 0 0` -> 1, `0 0 0` -> 0, `0 0 1` -> 1, `0 1 0` -> 1, `1 0 0` -> 1, `0 0 0` -> 0.

Here is how to compute `output_size` from the examples above. It is the sum of three terms:

1. `-(window-1)`: sliding a window of size `window` one step at a time yields `window-1` fewer outputs than there are positions.
2. `(input_size-1)*stride + 1`: the length of the input after inserting `stride-1` zeros between adjacent elements.
3. `+ pad_total`: the left and right padding.

So, `output_size = -(window-1) + (input_size-1)*stride + 1 + pad_total = input_size*stride - (window+stride-2) + pad_total`, which equals `input_size*stride` for "SAME" and "CAUSAL", and `input_size*stride + max(window-stride, 0)` for "VALID".

On the other hand, when `dilation > 1`, `dilate_window = (window - 1) * dilation + 1`. For example, when window=3 and dilation=2, dilate_window=5. In the stride=2 case, this function creates outputs as follows, where `*` inside a window marks a position skipped by the dilated kernel:

* "SAME", padding=(3, 2), paddings `0 0 0|0 * 0 * 1 * 1|0 0`: `0 * 0 * 0` -> 0 (three times), `0 * 0 * 1` -> 1, `0 * 0 * 0` -> 0, `0 * 1 * 1` -> 2, `0 * 0 * 0` -> 0, `1 * 1 * 0` -> 2.
* "VALID", padding=(4, 4), paddings `0 0 0 0|0 * 0 * 1 * 1|0 0 0 0`: `0 * 0 * 0` -> 0 (four times), `0 * 0 * 1` -> 1, `0 * 0 * 0` -> 0, `0 * 1 * 1` -> 2, `0 * 0 * 0` -> 0, `1 * 1 * 0` -> 2, `0 * 0 * 0` -> 0, `1 * 0 * 0` -> 1.
* "CAUSAL", padding=(4, 1), paddings `0 0 0 0|0 * 0 * 1 * 1|0`: `0 * 0 * 0` -> 0 (four times), `0 * 0 * 1` -> 1, `0 * 0 * 0` -> 0, `0 * 1 * 1` -> 2, `0 * 0 * 0` -> 0.
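The padding formulas and illustrations above can be cross-checked with a small pure-Python sketch. It assumes a single 1D channel and an all-ones kernel, as in the illustrations, and only models the literal paddings. The names `transpose_padding`, `expected_output_size`, and `conv_transpose_ones` are hypothetical helpers, not this PR's API, and `dilate_window` here merely restates the formula given for `conv_dilate_window()`.

```python
def dilate_window(window: int, dilation: int) -> int:
    """Effective window size of a dilated kernel: (window - 1) * dilation + 1."""
    return (window - 1) * dilation + 1


def transpose_padding(window: int, stride: int, padding: str, dilation: int = 1) -> tuple:
    """Returns (left, right) padding for a literal padding, per the formulas above."""
    # With dilation > 1, `window` is replaced by `dilate_window`.
    w = dilate_window(window, dilation)
    s = stride
    if padding == "SAME":
        pad_total = w + s - 2
        # (pad_total + 1) // 2 is ceil(pad_total / 2); pad_total // 2 is the floor.
        return min(w - 1, (pad_total + 1) // 2), max(s - 1, pad_total // 2)
    if padding == "VALID":
        return w - 1, max(s - 1, w - 1)
    if padding == "CAUSAL":
        return w - 1, s - 1
    raise ValueError(f"Unsupported padding: {padding!r}")


def expected_output_size(input_size: int, window: int, stride: int,
                         padding: str, dilation: int = 1) -> int:
    """Closed-form output size for each literal padding."""
    w = dilate_window(window, dilation)
    if padding in ("SAME", "CAUSAL"):
        return input_size * stride
    if padding == "VALID":
        return input_size * stride + max(w - stride, 0)
    raise ValueError(f"Unsupported padding: {padding!r}")


def conv_transpose_ones(x: list, window: int, stride: int,
                        padding: str, dilation: int = 1) -> list:
    """1D transposed convolution with an all-ones kernel, as in the illustrations."""
    # Insert stride-1 zeros between adjacent inputs (the `*` in the paddings rows).
    dilated = []
    for i, value in enumerate(x):
        if i > 0:
            dilated.extend([0] * (stride - 1))
        dilated.append(value)
    left, right = transpose_padding(window, stride, padding, dilation)
    padded = [0] * left + dilated + [0] * right
    span = dilate_window(window, dilation)
    # Slide the (dilated) window one step at a time; kernel taps are `dilation` apart,
    # and with all-ones weights each output is the sum of the tapped values.
    return [
        sum(padded[start + k * dilation] for k in range(window))
        for start in range(len(padded) - span + 1)
    ]
```

For instance, `conv_transpose_ones([0, 0, 1, 1], window=3, stride=2, padding="SAME")` returns `[0, 0, 0, 0, 1, 1, 2, 1]`, the window=3, stride=2 "SAME" case above. Materializing the inserted zeros keeps the sketch close to the illustrations; `jax.lax.conv_transpose()` instead expresses the stride as input (lhs) dilation of a regular convolution, which avoids building the zero-filled list.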
@ruomingp could you approve it? from 926
Thank you for review!