InterDigitalInc / CompressAI

A PyTorch library and evaluation platform for end-to-end compression research
https://interdigitalinc.github.io/CompressAI/
BSD 3-Clause Clear License

Issue Regarding the Use of GaussianMixtureConditional #289

Closed: formioq closed this issue 3 months ago

formioq commented 3 months ago

I want to use the GaussianMixtureConditional component from the library, but its usage seems to conflict with its parent class, GaussianConditional. Specifically, in the likelihood function of GaussianMixtureConditional, the line `M = inputs.size(1)` and the slicing that follows appear to assume a single Gaussian distribution: it slices from 0 to the channel count of inputs, which yields only one part. When I change it to `M = inputs.size(1) // self.K` so that the parameters are sliced into K parts (the number of Gaussian components), I get a channel-dimension mismatch in the parent class's likelihood function at `values = inputs - means`, because only means is sliced, not inputs. Is my understanding of the code incorrect, or can it really not handle multiple Gaussian distributions?

YodaEmbedding commented 3 months ago

M is the number of channels in y.

```python
B, M, H, W = y.shape
```

Please ensure:

```python
assert scales.shape[1] == M * K
assert means.shape[1] == M * K
```

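To make that shape contract concrete, here is a minimal pure-PyTorch sketch of a K-component Gaussian mixture likelihood. It assumes, as the question describes for the library, that component k's parameters occupy channels k*M to (k+1)*M, so `inputs` keeps its M channels and only `scales`, `means`, and `weights` are sliced. The helper names (`gaussian_bin_likelihood`, `mixture_likelihood`) and the 0.11 scale floor are illustrative assumptions, not CompressAI's code.

```python
import math

import torch


def gaussian_bin_likelihood(values, scales, means, scale_floor=0.11):
    # Probability mass of N(means, scales^2) over the unit-width bin centred
    # on each value. scale_floor is an illustrative lower bound on the scale.
    scales = scales.clamp(min=scale_floor)
    v = (values - means).abs()

    def cdf(x):
        # Standard normal CDF via the complementary error function.
        return 0.5 * torch.erfc(-x / math.sqrt(2.0))

    return cdf((0.5 - v) / scales) - cdf((-0.5 - v) / scales)


def mixture_likelihood(inputs, scales, means, weights, K):
    # inputs: (B, M, H, W); scales/means/weights: (B, M*K, H, W).
    M = inputs.size(1)  # channels of y, NOT divided by K
    assert scales.shape[1] == means.shape[1] == weights.shape[1] == M * K
    likelihood = torch.zeros_like(inputs)
    for k in range(K):
        sl = slice(M * k, M * (k + 1))  # channels belonging to component k
        likelihood = likelihood + weights[:, sl] * gaussian_bin_likelihood(
            inputs, scales[:, sl], means[:, sl]
        )
    return likelihood


torch.manual_seed(0)
B, M, H, W, K = 2, 8, 4, 4, 3
y = torch.randn(B, M, H, W)
scales = torch.rand(B, M * K, H, W) + 0.5
means = torch.randn(B, M * K, H, W)
# Per-channel weights that sum to 1 over the K components.
weights = torch.softmax(torch.randn(B, K, M, H, W), dim=1).reshape(B, M * K, H, W)

lik = mixture_likelihood(y, scales, means, weights, K)
print(lik.shape)  # torch.Size([2, 8, 4, 4])
```

Because `inputs` is never sliced, `values = inputs - means` always compares M channels with M channels; only the parameter tensors need M * K channels.
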
formioq commented 3 months ago

Thank you very much for your answer; it now works correctly. However, I noticed that the bpp (calculated with your RateDistortionLoss) has increased significantly: at lambda=0.01, the bpp was around 0.45 with GC, but it is over 10 with GMM. Do you know how to fix this?
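For reference, here is a sketch of how a bpp term can be computed from a model's likelihoods: the sum of -log2 likelihoods over all latent elements, divided by the number of image pixels. As far as I know, CompressAI's `RateDistortionLoss` computes `bpp_loss` this way, but check the library source; `bpp_from_likelihoods` is a hypothetical helper and the sizes are illustrative. The example also shows that if every likelihood is too small by a constant factor, each latent element pays log2 of that factor in extra bits.

```python
import math

import torch


def bpp_from_likelihoods(likelihoods, num_pixels):
    # Hypothetical helper: total bits = sum of -log2(p) over every latent
    # element, divided by the number of image pixels (batch * H * W).
    return sum(
        torch.log(lik).sum() / (-math.log(2) * num_pixels)
        for lik in likelihoods.values()
    )


# Illustrative sizes: one 256x256 image, y downsampled by 16, M = 192.
batch, H, W, M = 1, 256, 256, 192
num_pixels = batch * H * W
lik_y = torch.full((batch, M, H // 16, W // 16), 0.5)  # 1 bit per latent element

bpp = bpp_from_likelihoods({"y": lik_y}, num_pixels)
# Dividing every likelihood by M costs an extra log2(M) bits per element.
bpp_shrunk = bpp_from_likelihoods({"y": lik_y / M}, num_pixels)
print(bpp.item(), bpp_shrunk.item())
```

With these sizes the baseline is 0.75 bpp and the shrunken version is about 6.44 bpp, i.e. 0.75 * (1 + log2(192)).
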

YodaEmbedding commented 3 months ago

Are you training with the same settings? After the first epoch of training, bpp_loss should usually be ≤ 3.

The only difference between the two experiments should be:

What architecture is your model based on (e.g., mbt2018-mean, cheng2020-anchor, elic2022-chandelier, ...)?

formioq commented 3 months ago

Yes, I am using 'mbt2018' as the base model (with a context model), but I changed the output of entropy_parameters to M*9 channels because GMM seems to require an additional weight parameter. I used chunk(3, 1) to split it into scales_hat, means_hat, and weight_hat, each with M*3 channels, to model y with a K=3 GMM. All other settings were the same in both experiments. Here is the modified code (the commented-out lines are the changes for the GMM run; I also used softmax to make the weights sum to 1). Could it be an issue with the weight settings, or is there a logical misunderstanding on my part? I look forward to your response.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from compressai.entropy_models import (
    EntropyBottleneck,
    GaussianConditional,
    GaussianMixtureConditional,
)
from compressai.layers import GDN, CheckerboardMaskedConv2d
from compressai.models.base import CompressionModel
from compressai.models.utils import conv, deconv


class CheckerboardMbt2018(CompressionModel):  # class name is hypothetical
    r"""

    ..
                  ┌───┐    y     ┌───┐  z  ┌───┐ z_hat      z_hat ┌───┐
            x ──►─┤g_a├──►─┬──►──┤h_a├──►──┤ Q ├───►───·⋯·───►───┤h_s├─┐
                  └───┘    │     └───┘     └───┘        EB        └───┘ │
                           ▼                                            │
                         ┌─┴─┐                                          │
                         │ Q │                                   params ▼
                         └─┬─┘                                          │
                     y_hat ▼                  ┌─────┐                   │
                           ├──────────►───────┤  CP ├────────►──────────┤
                           │                  └─────┘                   │
                           ▼                                            ▼
                           │                                            │
                           ·                  ┌─────┐                   │
                        GC : ◄────────◄───────┤  EP ├────────◄──────────┘
                           ·     scales_hat   └─────┘
                           │      means_hat
                     y_hat ▼
                           │
                  ┌───┐    │
        x_hat ──◄─┤g_s├────┘
                  └───┘

        EB = Entropy bottleneck
        GC = Gaussian conditional
        EP = Entropy parameters network
        CP = Context prediction (checkerboard)
    """

    def __init__(self, N, M, **kwargs):
        super().__init__(**kwargs)

        self.entropy_bottleneck = EntropyBottleneck(N)

        self.g_a = nn.Sequential(
            conv(3, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, M),
        )

        self.g_s = nn.Sequential(
            deconv(M, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, 3),
        )

        self.h_a = nn.Sequential(
            conv(M, N, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
            conv(N, N),
            nn.ReLU(inplace=True),
            conv(N, N),
        )

        self.h_s = nn.Sequential(
            deconv(N, N),
            nn.ReLU(inplace=True),
            deconv(N, M * 3 // 2),
            nn.ReLU(inplace=True),
            conv(M * 3 // 2, M * 2, stride=1, kernel_size=3),
            nn.ReLU(inplace=True),
        )

        # Input: 2M channels from h_s + 2M channels from context_prediction.
        self.entropy_parameters = nn.Sequential(
            nn.Conv2d(M * 12 // 3, M * 10 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(M * 10 // 3, M * 8 // 3, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(M * 8 // 3, M * 6 // 3, 1),  # GC: scales + means
            # nn.Conv2d(M * 8 // 3, M * 9, 1),  # GMM: scales + means + weights
        )

        self.context_prediction = CheckerboardMaskedConv2d(
            M, 2 * M, kernel_size=5, padding=2, stride=1, mask_type='B'
        )
        self.gaussian_conditional = GaussianConditional(None)
        # self.gaussian_conditional = GaussianMixtureConditional(K=3)

        self.N = int(N)
        self.M = int(M)

    @property
    def downsampling_factor(self) -> int:
        return 2 ** (4 + 2)

    def forward(self, x):
        y = self.g_a(x)
        z = self.h_a(torch.abs(y))
        z_hat, z_likelihoods = self.entropy_bottleneck(z)

        params = self.h_s(z_hat)
        y_hat = self.gaussian_conditional.quantize(
            y, "noise" if self.training else "dequantize"
        )
        ctx_hat = self.context_prediction(y_hat)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_hat), dim=1)
        )

        # scales_hat, means_hat, weight_hat = gaussian_params.chunk(3, 1)
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        # weight_hat = F.softmax(weight_hat, dim=1)

        y_hat1, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
        # y_hat1, y_likelihoods = self.gaussian_conditional(
        #     y, scales_hat, means=means_hat, weights=weight_hat
        # )

        x_hat = self.g_s(y_hat1)

        return {
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }
```
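The channel bookkeeping of the GMM variant can be checked in isolation. This sketch uses a random tensor in place of the real `entropy_parameters` output (all sizes are arbitrary assumptions): M*9 channels chunk into three tensors of M*3 channels each, which matches the `M * K` contract for K = 3. It also shows what `F.softmax(weight_hat, dim=1)` does on that layout: it normalizes across all M*3 channels at each spatial position, so the K weights that belong to one latent channel do not sum to 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, M, H, W, K = 2, 16, 4, 4, 3

# Stand-in for entropy_parameters(...) with M * 9 output channels.
gaussian_params = torch.randn(B, M * 9, H, W)
scales_hat, means_hat, weight_hat = gaussian_params.chunk(3, 1)
for t in (scales_hat, means_hat, weight_hat):
    assert t.shape == (B, M * K, H, W)  # M * K channels each

# Softmax over dim=1 normalizes across all M * K channels at once...
w_all = F.softmax(weight_hat, dim=1)
total = w_all.sum(dim=1)  # (B, H, W): sums to 1 per position
# ...so per latent channel the K component weights average only 1 / M.
per_channel = w_all.reshape(B, K, M, H, W).sum(dim=1)  # (B, M, H, W)
print(total.mean().item(), per_channel.mean().item())
```
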
YodaEmbedding commented 3 months ago

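The softmax on weight_hat should be taken along a dimension of length 3 (the K = 3 mixture components), not across all M*3 channels at once.

You may need to reshape so that the softmax is applied along that length-3 dimension, and then reshape back so that the result is compatible with the GaussianMixtureConditional interface. Perhaps: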

```python
import torch.nn.functional as F

B, Mx3, H, W = weight_hat.shape
M = Mx3 // 3

weight_hat = F.softmax(
    weight_hat.reshape(B, 3, M, H, W), dim=1
).reshape(B, Mx3, H, W)
```

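Here is a runnable sketch of that reshape, softmax, reshape pattern, plus a check that the K = 3 weights of each latent channel now sum to 1. It assumes the channel layout implied by the slicing discussed earlier in the thread (component k in channels k*M to (k+1)*M); `normalize_mixture_weights` and `wrong_softmax_bpp_floor` are hypothetical helpers. The second one is an idealized back-of-the-envelope estimate of the rate penalty of the all-channel softmax, assuming y is downsampled by 16 and ignoring any likelihood lower bound.

```python
import math

import torch
import torch.nn.functional as F


def normalize_mixture_weights(weight_hat, K=3):
    # weight_hat: (B, M*K, H, W), component k in channels [k*M, (k+1)*M).
    B, MK, H, W = weight_hat.shape
    M = MK // K
    w = F.softmax(weight_hat.reshape(B, K, M, H, W), dim=1)  # over the K axis
    return w.reshape(B, MK, H, W)


def wrong_softmax_bpp_floor(M, downsampling=16):
    # Assumption: softmax over all M*K channels, so the per-channel weight
    # sums s_m satisfy sum_m s_m = 1. Each component's bin probability is
    # at most 1, so each y likelihood is at most s_m, giving at least
    # sum_m -log2(s_m) >= M * log2(M) bits per latent position (AM-GM).
    return M * math.log2(M) / downsampling**2


torch.manual_seed(0)
B, M, H, W, K = 2, 16, 4, 4, 3
w = normalize_mixture_weights(torch.randn(B, M * K, H, W), K)
per_channel = w.reshape(B, K, M, H, W).sum(dim=1)
print(torch.allclose(per_channel, torch.ones_like(per_channel)))  # True
print(round(wrong_softmax_bpp_floor(192), 2), round(wrong_softmax_bpp_floor(320), 2))  # 5.69 10.4
```

Under these assumptions the floor is about 5.69 bpp for M = 192 and about 10.4 bpp for M = 320, so normalizing across all M*3 channels can plausibly account for bpp values of the magnitude reported above, while the per-channel softmax removes that floor.
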
formioq commented 3 months ago

The problem is solved. Thank you very much, you are really amazing!