AaltoML / generative-inverse-heat-dissipation

Code release for the paper Generative Modeling With Inverse Heat Dissipation
https://aaltoml.github.io/generative-inverse-heat-dissipation/
MIT License

The batch dim in dct() function #4

Open kitaev-chen opened 8 months ago

kitaev-chen commented 8 months ago

I'm confused about the implementation of the dct function.

1

The first problem is the batch dim N = x_shape[-1]. Whether the data is 2D or 3D, with dim format (N, L, D) or (N, C, H, W), isn't the batch size N = x_shape[0]?

def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)
    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1] # <==============
    x = x.contiguous().view(-1, N)
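A minimal NumPy sketch of what the quoted reshape does, assuming an orthonormal DCT-II matrix (the same scaling as norm='ortho') in place of torch_dct's FFT-based routine. `dct_matrix` and `dct_last_dim` are hypothetical helper names, not functions from the repo. It shows that N is the transform length of the last axis and that view(-1, N) turns every leading index (batch, channel, row) into an independent 1D signal:

```python
import numpy as np


def dct_matrix(N):
    """Orthonormal DCT-II matrix: C[k, n] = s_k * cos(pi * (2n + 1) * k / (2N))."""
    n = np.arange(N)
    C = np.cos(np.pi * (2 * n[None, :] + 1) * n[:, None] / (2 * N))
    C[0, :] *= np.sqrt(1.0 / N)   # DC row scaling for norm='ortho'
    C[1:, :] *= np.sqrt(2.0 / N)  # remaining rows
    return C


def dct_last_dim(x):
    # Same bookkeeping as the quoted dct(): N is the length of the LAST axis,
    # not the batch size.
    x_shape = x.shape
    N = x_shape[-1]
    rows = x.reshape(-1, N)         # every leading index becomes one 1D signal
    out = rows @ dct_matrix(N).T    # DCT-II of each row independently
    return out.reshape(x_shape)     # restore the original leading dims


x = np.random.default_rng(0).normal(size=(2, 3, 4, 5))  # e.g. (batch, C, H, W)
y = dct_last_dim(x)
print(x.reshape(-1, x.shape[-1]).shape)  # (24, 5): 2*3*4 signals of length 5
```

Because the batch axis is just one of the folded leading axes, transforming one batch element alone gives the same result as transforming the whole batch.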

2

Another question is about DCTBlur for image data. Image data is 3D, with dims (C, H, W). Why do you use a 2D DCT instead of a 3D one?

import numpy as np
import torch
import torch.nn as nn
import torch_dct

class DCTBlur(nn.Module):

    def __init__(self, blur_sigmas, image_size, device):
        super(DCTBlur, self).__init__()
        self.blur_sigmas = torch.tensor(blur_sigmas).to(device)
        # Frequencies pi*k/image_size for k = 0..image_size-1
        freqs = np.pi*torch.linspace(0, image_size-1,
                                     image_size).to(device)/image_size
        # Squared 2D frequency of every (k, l) DCT coefficient, shape (H, W)
        self.frequencies_squared = freqs[:, None]**2 + freqs[None, :]**2

    def forward(self, x, fwd_steps):
        # Per-sample blur level, reshaped to broadcast over (C, H, W) or (H, W)
        if len(x.shape) == 4:
            sigmas = self.blur_sigmas[fwd_steps][:, None, None, None]
        elif len(x.shape) == 3:
            sigmas = self.blur_sigmas[fwd_steps][:, None, None]
        t = sigmas**2/2
        dct_coefs = torch_dct.dct_2d(x, norm='ortho') # <==============
        # Damp each frequency component by exp(-freq^2 * t)
        dct_coefs = dct_coefs * torch.exp(- self.frequencies_squared * t)
        return torch_dct.idct_2d(dct_coefs, norm='ortho')
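The forward pass above can be sketched in NumPy to see what each step does. This is an illustration under assumptions, not the repo's code: it swaps torch_dct for an explicit orthonormal DCT-II matrix, uses a single scalar `sigma` instead of the per-sample `blur_sigmas[fwd_steps]`, and `dct_matrix` / `dct_blur` are hypothetical names. Each 2D cosine mode is multiplied by exp(-(freq_k^2 + freq_l^2) * t) with t = sigma^2 / 2, so sigma = 0 is the identity and the image mean (the DC coefficient, whose frequency is 0) is never changed:

```python
import numpy as np


def dct_matrix(N):
    """Orthonormal DCT-II matrix, matching norm='ortho'."""
    n = np.arange(N)
    C = np.cos(np.pi * (2 * n[None, :] + 1) * n[:, None] / (2 * N))
    C[0, :] *= np.sqrt(1.0 / N)
    C[1:, :] *= np.sqrt(2.0 / N)
    return C


def dct_blur(x, sigma):
    """Heat-dissipation style blur of x with shape (..., N, N)."""
    N = x.shape[-1]
    C = dct_matrix(N)
    freqs = np.pi * np.arange(N) / N                     # same as pi*linspace(0, N-1, N)/N
    freqs_sq = freqs[:, None] ** 2 + freqs[None, :] ** 2
    t = sigma ** 2 / 2
    coefs = C @ x @ C.T                                  # 2D DCT over the last two axes
    coefs = coefs * np.exp(-freqs_sq * t)                # damp each mode by exp(-freq^2 * t)
    return C.T @ coefs @ C                               # inverse 2D DCT (C is orthogonal)


img = np.random.default_rng(0).normal(size=(8, 8))
blurred = dct_blur(img, sigma=2.0)
```

Higher frequencies have larger freq^2 and therefore decay faster, which is why the result looks progressively blurrier as sigma grows.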

KokeCacao commented 1 month ago

I suppose:

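  1. N does not represent the batch size here. It is the length of the last dimension, the one the DCT is computed over; view(-1, N) folds all leading dimensions into a batch of 1D signals.
  2. MNIST is grayscale, so each image is 2D.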
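A short NumPy check related to the second point, assuming (as the 4D branch of DCTBlur.forward suggests) that color inputs arrive as (B, C, H, W). A 2D DCT over the last two axes treats the channel axis like another batch axis, so each channel is transformed, and hence blurred, independently; a 3D DCT would also transform along the channel axis and mix channels. `dct_matrix` and `dct_2d` are hypothetical helpers using an orthonormal DCT-II matrix in place of torch_dct:

```python
import numpy as np


def dct_matrix(N):
    """Orthonormal DCT-II matrix, matching norm='ortho'."""
    n = np.arange(N)
    C = np.cos(np.pi * (2 * n[None, :] + 1) * n[:, None] / (2 * N))
    C[0, :] *= np.sqrt(1.0 / N)
    C[1:, :] *= np.sqrt(2.0 / N)
    return C


def dct_2d(x):
    # 2D DCT over the last two axes only (H, then W); leading axes are untouched
    C_h = dct_matrix(x.shape[-2])
    C_w = dct_matrix(x.shape[-1])
    return C_h @ x @ C_w.T


rgb = np.random.default_rng(1).normal(size=(4, 3, 8, 8))  # (batch, C, H, W)
gray = rgb[:, :1]                                          # (batch, 1, H, W)

# Transforming each channel on its own gives the same coefficients
per_channel = np.stack([dct_2d(rgb[:, c]) for c in range(3)], axis=1)
```

So the same 2D code path covers grayscale and color images; only the number of independent channels differs.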