wavefrontshaping / complexPyTorch

A high-level toolbox for using complex valued neural networks in PyTorch
MIT License
623 stars 148 forks source link

ComplexConvTransposeNd #26

Open H320 opened 1 year ago

H320 commented 1 year ago

https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Convolution.cpp#L812

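Complex convolution: conv(W, x, b) = conv(Wr, xr, br) - conv(Wi, xi, 0) + i(conv(Wi, xr, bi) + conv(Wr, xi, 0)), where W, x and b are all complex inputs and the subscripts r and i denote real and imaginary parts. With Gauss's trick only three real convolutions are needed: let a = conv(Wr, xr, br), b = conv(Wi, xi, 0) and c = conv(Wr + Wi, xr + xi, bi + br); then conv(W, x, b) = a - b + i(c - a - b). This works because c expands to conv(Wr, xr) + conv(Wr, xi) + conv(Wi, xr) + conv(Wi, xi) + br + bi, so c - a - b = conv(Wi, xr, bi) + conv(Wr, xi, 0). (The intermediate a, b, c are distinct from the bias b.)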

from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn

class ComplexConvTranspose1dn(nn.ConvTranspose1d):
    """1D transposed convolution over complex tensors, using Gauss's trick.

    Build the module with a complex ``dtype`` (e.g. ``dtype=torch.cfloat``) so
    that ``weight`` and ``bias`` are complex. Writing ``W = Wr + i*Wi``,
    ``x = xr + i*xi`` and ``bias = br + i*bi``, the output is computed with
    three real ``F.conv_transpose1d`` calls instead of four::

        a = conv(Wr, xr, br)
        b = conv(Wi, xi, 0)
        c = conv(Wr + Wi, xr + xi, br + bi)
        out = (a - b) + i * (c - a - b)

    Real-valued inputs or weights are accepted and treated as having a zero
    imaginary part; ``bias=False`` is supported.
    """

    @staticmethod
    def _real_imag(t: Tensor):
        """Split ``t`` into ``(real, imag)``; a real tensor gets a zero imaginary part."""
        # Tensor.imag raises for non-complex dtypes, so synthesize zeros instead.
        if t.is_complex():
            return t.real, t.imag
        return t, torch.zeros_like(t)

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the complex transposed convolution to ``input``.

        Args:
            input: complex (or real) tensor accepted by ``F.conv_transpose1d``.
            output_size: optional target output size, resolved into
                ``output_padding`` via ``nn.ConvTranspose1d._output_padding``.

        Returns:
            Complex tensor equal to ``conv_transpose1d(input, weight, bias)``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 1
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        i_r, i_i = self._real_imag(input)
        w_r, w_i = self._real_imag(self.weight)
        if self.bias is None:
            # Built with bias=False: self.bias is None, so there are no bias terms.
            b_r = None
            b_sum = None
        else:
            b_r, b_i = self._real_imag(self.bias)
            b_sum = b_r + b_i

        a = F.conv_transpose1d(i_r, w_r, b_r, self.stride, self.padding, output_padding, self.groups, self.dilation)
        b = F.conv_transpose1d(i_i, w_i, None, self.stride, self.padding, output_padding, self.groups, self.dilation)
        c = F.conv_transpose1d(i_r + i_i, w_r + w_i, b_sum, self.stride, self.padding, output_padding, self.groups, self.dilation)

        # Real part: a - b; imaginary part: c - a - b (see class docstring).
        return torch.complex(a - b, c - a - b)

class ComplexConvTranspose2dn(nn.ConvTranspose2d):
    """2D transposed convolution over complex tensors, using Gauss's trick.

    Build the module with a complex ``dtype`` (e.g. ``dtype=torch.cfloat``) so
    that ``weight`` and ``bias`` are complex. Writing ``W = Wr + i*Wi``,
    ``x = xr + i*xi`` and ``bias = br + i*bi``, the output is computed with
    three real ``F.conv_transpose2d`` calls instead of four::

        a = conv(Wr, xr, br)
        b = conv(Wi, xi, 0)
        c = conv(Wr + Wi, xr + xi, br + bi)
        out = (a - b) + i * (c - a - b)

    Real-valued inputs or weights are accepted and treated as having a zero
    imaginary part; ``bias=False`` is supported.
    """

    @staticmethod
    def _real_imag(t: Tensor):
        """Split ``t`` into ``(real, imag)``; a real tensor gets a zero imaginary part."""
        # Tensor.imag raises for non-complex dtypes, so synthesize zeros instead.
        if t.is_complex():
            return t.real, t.imag
        return t, torch.zeros_like(t)

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the complex transposed convolution to ``input``.

        Args:
            input: complex (or real) tensor accepted by ``F.conv_transpose2d``.
            output_size: optional target output size, resolved into
                ``output_padding`` via ``nn.ConvTranspose2d._output_padding``.

        Returns:
            Complex tensor equal to ``conv_transpose2d(input, weight, bias)``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 2
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        i_r, i_i = self._real_imag(input)
        w_r, w_i = self._real_imag(self.weight)
        if self.bias is None:
            # Built with bias=False: self.bias is None, so there are no bias terms.
            b_r = None
            b_sum = None
        else:
            b_r, b_i = self._real_imag(self.bias)
            b_sum = b_r + b_i

        a = F.conv_transpose2d(i_r, w_r, b_r, self.stride, self.padding, output_padding, self.groups, self.dilation)
        b = F.conv_transpose2d(i_i, w_i, None, self.stride, self.padding, output_padding, self.groups, self.dilation)
        c = F.conv_transpose2d(i_r + i_i, w_r + w_i, b_sum, self.stride, self.padding, output_padding, self.groups, self.dilation)

        # Real part: a - b; imaginary part: c - a - b (see class docstring).
        return torch.complex(a - b, c - a - b)

class ComplexConvTranspose3dn(nn.ConvTranspose3d):
    """3D transposed convolution over complex tensors, using Gauss's trick.

    Build the module with a complex ``dtype`` (e.g. ``dtype=torch.cfloat``) so
    that ``weight`` and ``bias`` are complex. Writing ``W = Wr + i*Wi``,
    ``x = xr + i*xi`` and ``bias = br + i*bi``, the output is computed with
    three real ``F.conv_transpose3d`` calls instead of four::

        a = conv(Wr, xr, br)
        b = conv(Wi, xi, 0)
        c = conv(Wr + Wi, xr + xi, br + bi)
        out = (a - b) + i * (c - a - b)

    Real-valued inputs or weights are accepted and treated as having a zero
    imaginary part; ``bias=False`` is supported.
    """

    @staticmethod
    def _real_imag(t: Tensor):
        """Split ``t`` into ``(real, imag)``; a real tensor gets a zero imaginary part."""
        # Tensor.imag raises for non-complex dtypes, so synthesize zeros instead.
        if t.is_complex():
            return t.real, t.imag
        return t, torch.zeros_like(t)

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the complex transposed convolution to ``input``.

        Args:
            input: complex (or real) tensor accepted by ``F.conv_transpose3d``.
            output_size: optional target output size, resolved into
                ``output_padding`` via ``nn.ConvTranspose3d._output_padding``.

        Returns:
            Complex tensor equal to ``conv_transpose3d(input, weight, bias)``.

        Raises:
            ValueError: if ``padding_mode`` is not ``'zeros'``.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d')

        assert isinstance(self.padding, tuple)
        # One cannot replace List by Tuple or Sequence in "_output_padding" because
        # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
        num_spatial_dims = 3
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size,  # type: ignore[arg-type]
            num_spatial_dims, self.dilation)  # type: ignore[arg-type]

        i_r, i_i = self._real_imag(input)
        w_r, w_i = self._real_imag(self.weight)
        if self.bias is None:
            # Built with bias=False: self.bias is None, so there are no bias terms.
            b_r = None
            b_sum = None
        else:
            b_r, b_i = self._real_imag(self.bias)
            b_sum = b_r + b_i

        a = F.conv_transpose3d(i_r, w_r, b_r, self.stride, self.padding, output_padding, self.groups, self.dilation)
        b = F.conv_transpose3d(i_i, w_i, None, self.stride, self.padding, output_padding, self.groups, self.dilation)
        c = F.conv_transpose3d(i_r + i_i, w_r + w_i, b_sum, self.stride, self.padding, output_padding, self.groups, self.dilation)

        # Real part: a - b; imaginary part: c - a - b (see class docstring).
        return torch.complex(a - b, c - a - b)