Virtsionis / torch-nilm

THE VAE IS WRONG #13

Closed: zhgqcn closed this issue 1 year ago

zhgqcn commented 2 years ago

[image: screenshot of the VAE code, pointing out the missing deconvolution layers in the decoder]

Virtsionis commented 2 years ago

Hi! I missed that. Your change was simply to add the corresponding deconv layers, right? Thanks a lot for pointing that out! I will make the change ASAP!

PS: The caps weren't necessary... It's a two-person open-source project and we are trying to help people here.

Virtsionis commented 2 years ago

Hi @zhgqcn! It would save me some time if you posted your changes to the VAE model. Thanks for contributing to our project.

zhgqcn commented 1 year ago

import torch
import torch.nn as nn

# IBNNet, LinearDropRelu, ConvDropRelu and VIBNet below are the existing
# torch-nilm building blocks.


class VAE(VIBNet):
    '''
    Architecture introduced in: ENERGY DISAGGREGATION USING VARIATIONAL AUTOENCODERS
    https://arxiv.org/pdf/2103.12177.pdf
    '''
    def __init__(self, window_size=1024, cnn_dim=256, kernel_size=3, latent_dim=16, max_noise=0.1, output_dim=1024, dropout=0.):
        super().__init__()
        self.K = latent_dim
        self.max_noise = max_noise
        self.drop = dropout
        self.window = window_size
        self.dense_input = self.K * (self.window // 4)  # flattened encoder output; equals cnn_dim * (window // 64) for the default sizes

        '''
        ENCODER
        '''
        # Six max-pooled IBNNet blocks halve the sequence length six times,
        # leaving window // 64 samples at the bottleneck.
        self.conv_seq1 = IBNNet(input_channels=1, output_dim=cnn_dim, kernel_size=kernel_size, max_pool=True, residual=False)
        self.conv_seq2 = IBNNet(input_channels=cnn_dim, output_dim=cnn_dim, kernel_size=kernel_size, max_pool=True)
        self.conv_seq3 = IBNNet(input_channels=cnn_dim, output_dim=cnn_dim, kernel_size=kernel_size, max_pool=True)
        self.conv_seq4 = IBNNet(input_channels=cnn_dim, output_dim=cnn_dim, kernel_size=kernel_size, max_pool=True)
        self.conv_seq5 = IBNNet(input_channels=cnn_dim, output_dim=cnn_dim, kernel_size=kernel_size, max_pool=True)
        self.conv_seq6 = IBNNet(input_channels=cnn_dim, output_dim=cnn_dim, kernel_size=kernel_size, max_pool=True, inst_norm=False)
        self.conv_seq7 = IBNNet(input_channels=cnn_dim, output_dim=cnn_dim, kernel_size=kernel_size, max_pool=False, inst_norm=False)

        '''
        REPARAMETRIZATION TRICK
        '''
        self.flatten1 = nn.Flatten()
        # the dense layer emits 2 * latent_dim values: latent_dim means and latent_dim log-variances
        self.dense = LinearDropRelu(self.dense_input, 2 * latent_dim, self.drop)
        # projects the sampled z back to a sequence of length window // 64
        self.reshape1 = nn.Linear(self.K, self.window // 64)

        '''
        DECODER
        '''
        self.dconv_seq4 = IBNNet(input_channels=1, output_dim=cnn_dim, kernel_size=kernel_size, inst_norm=False, residual=False, max_pool=False)
        # Each deconv takes 2 * cnn_dim channels (a decoder block concatenated
        # with its mirrored encoder block) and doubles the sequence length,
        # mirroring the six max-pool halvings of the encoder.
        self.deconv1 = nn.ConvTranspose1d(in_channels=2 * cnn_dim, out_channels=cnn_dim, kernel_size=kernel_size, stride=2, padding=1, output_padding=1, padding_mode='zeros')

        self.dconv_seq5 = IBNNet(input_channels=cnn_dim, output_dim=cnn_dim, kernel_size=kernel_size, inst_norm=False, max_pool=False)
        self.deconv2 = nn.ConvTranspose1d(in_channels=2 * cnn_dim, out_channels=cnn_dim, kernel_size=kernel_size, stride=2, padding=1, output_padding=1, padding_mode='zeros')

        self.dconv_seq6 = IBNNet(input_channels=cnn_dim, output_dim=cnn_dim, kernel_size=kernel_size, inst_norm=False, max_pool=False)
        self.deconv3 = nn.ConvTranspose1d(in_channels=2 * cnn_dim, out_channels=cnn_dim, kernel_size=kernel_size, stride=2, padding=1, output_padding=1, padding_mode='zeros')

        self.dconv_seq7 = IBNNet(input_channels=cnn_dim, output_dim=cnn_dim, kernel_size=kernel_size, inst_norm=False, max_pool=False)
        self.deconv4 = nn.ConvTranspose1d(in_channels=2 * cnn_dim, out_channels=cnn_dim, kernel_size=kernel_size, stride=2, padding=1, output_padding=1, padding_mode='zeros')

        self.dconv_seq8 = IBNNet(input_channels=cnn_dim, output_dim=cnn_dim, kernel_size=kernel_size, inst_norm=False, max_pool=False)
        self.deconv5 = nn.ConvTranspose1d(in_channels=2 * cnn_dim, out_channels=cnn_dim, kernel_size=kernel_size, stride=2, padding=1, output_padding=1, padding_mode='zeros')

        self.dconv_seq9 = IBNNet(input_channels=cnn_dim, output_dim=cnn_dim, kernel_size=kernel_size, inst_norm=False, max_pool=False)
        self.deconv6 = nn.ConvTranspose1d(in_channels=2 * cnn_dim, out_channels=cnn_dim, kernel_size=kernel_size, stride=2, padding=1, output_padding=1, padding_mode='zeros')

        self.dconv_seq10 = IBNNet(input_channels=cnn_dim, output_dim=cnn_dim, kernel_size=kernel_size, inst_norm=False, max_pool=False)

        self.outputs = nn.Sequential(
            ConvDropRelu(in_channels=2 * cnn_dim, out_channels=1, kernel_size=kernel_size)
            # nn.Linear(self.window, output_dim)
        )

    def forward(self, x, current_epoch=1, num_sample=1):
        # x arrives as [batch_size, window_size, 1]; Conv1d expects
        # [batch_size, channels, length], hence the permute
        x = x.permute(0, 2, 1)

        conv_seq1, pool1 = self.conv_seq1(x)
        conv_seq2, pool2 = self.conv_seq2(pool1)
        conv_seq3, pool3 = self.conv_seq3(pool2)
        conv_seq4, pool4 = self.conv_seq4(pool3)
        conv_seq5, pool5 = self.conv_seq5(pool4)
        conv_seq6, pool6 = self.conv_seq6(pool5)
        conv_seq7, pool7 = self.conv_seq7(pool6)

        flatten1 = self.flatten1(conv_seq7)

        statistics = self.dense(flatten1)
        mu = statistics[:, :self.K]
        # the second half of statistics is treated as log-variance
        # std = F.softplus(statistics[:, self.K:], beta=1)
        std = torch.exp(0.5 * statistics[:, self.K:])
        z = self.reparametrize_n(mu, std, current_epoch, num_sample, self.max_noise)  # inherited from VIBNet
        reshape1 = self.reshape1(z).unsqueeze(1)  # [batch, 1, window // 64]

        dconv_seq4, _ = self.dconv_seq4(reshape1)
        dconc5 = torch.cat((dconv_seq4, conv_seq7), 1)  # skip connection from the encoder
        deconv1 = self.deconv1(dconc5)

        dconv_seq5, _ = self.dconv_seq5(deconv1)
        dconc7 = torch.cat((dconv_seq5, conv_seq6), 1)
        deconv2 = self.deconv2(dconc7)

        dconv_seq6, _ = self.dconv_seq6(deconv2)
        dconc9 = torch.cat((dconv_seq6, conv_seq5), 1)
        deconv3 = self.deconv3(dconc9)

        dconv_seq7, _ = self.dconv_seq7(deconv3)
        dconc11 = torch.cat((dconv_seq7, conv_seq4), 1)
        deconv4 = self.deconv4(dconc11)

        dconv_seq8, _ = self.dconv_seq8(deconv4)
        dconc13 = torch.cat((dconv_seq8, conv_seq3), 1)
        deconv5 = self.deconv5(dconc13)

        dconv_seq9, _ = self.dconv_seq9(deconv5)
        dconc15 = torch.cat((dconv_seq9, conv_seq2), 1)
        deconv6 = self.deconv6(dconc15)

        dconv_seq10, _ = self.dconv_seq10(deconv6)
        dconc17 = torch.cat((dconv_seq10, conv_seq1), 1)

        outputs = self.outputs(dconc17)

        return (mu, std), outputs.permute(0, 2, 1)
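
For reference, a quick shape check of the revised model (a minimal sketch; it assumes the torch-nilm building blocks above are importable and that reparametrize_n returns a [batch, latent_dim] sample when num_sample=1):

model = VAE(window_size=1024, cnn_dim=256, latent_dim=16)
x = torch.randn(8, 1024, 1)          # [batch, window_size, 1] mains windows
(mu, std), y = model(x)
print(mu.shape, std.shape, y.shape)  # expected: [8, 16] [8, 16] [8, 1024, 1]

The six max-pooled encoder blocks and the six stride-2 deconvolutions cancel out, so the output length matches the input window.
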
oublalkhalid commented 1 year ago

Hi @Virtsionis and @zhgqcn! First, thank you for opening up this interesting discussion on variational autoencoders.

I've examined both models (the one offered by @Virtsionis and the revised version from @zhgqcn), and neither of them qualifies as a VAE, for several reasons. Using skip connections to carry information from the encoder to the decoder contradicts the fundamental idea of variational inference: the KL divergence becomes irrelevant because the information it is supposed to constrain flows around the bottleneck entirely.

[image: figure accompanying the skip-connection analysis]

In summary, the bottleneck of the model is completely bypassed. I confirmed this by inspecting the weights, and my results support this conclusion: the connection between dconv_seq9 and conv_seq2 is the most informative in the model. To further validate this, add a noisy value to the latent space $z$ and observe the changes (a probe sketch follows below); unfortunately, this won't affect the model's output, because the decoder reconstructs from the skips rather than from $z$. The model therefore violates the minimal definition of a variational autoencoder, as described in the original paper on Energy Disaggregation Using Variational Autoencoders (https://arxiv.org/pdf/2103.12177.pdf). A second problem I investigated is the sampling of $z$: even if we suppose there are no skips between the encoder $E_{\theta}$ and the decoder $D_{\phi}$, the sampling of $z$ is wrong, and so is the loss. Please change the loss to the standard evidence lower bound:

$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_{\phi}(z|x)}\big[\log p_{\theta}(x|z)\big] - D_{KL}\big(q_{\phi}(z|x) \,\|\, p(z)\big)$$
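
A minimal sketch of that perturbation probe, assuming a trained instance of the VAE class above (the hook function add_noise is illustrative, not part of torch-nilm):

import torch

model.eval()  # 'model' is a trained VAE instance
x = torch.randn(4, 1024, 1)  # or a real batch of mains windows

with torch.no_grad():
    _, y_clean = model(x)

def add_noise(module, inputs, output):
    # perturb the decoder's view of the latent code, right after reshape1
    return output + 5.0 * torch.randn_like(output)

handle = model.reshape1.register_forward_hook(add_noise)
with torch.no_grad():
    _, y_noisy = model(x)
handle.remove()

# If the skip connections truly bypass the bottleneck, this difference
# stays near zero even for large perturbations of z.
print((y_clean - y_noisy).abs().mean())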

⚠️ I strongly recommend renaming this model from VAE to Residual UNet or SkipAE, to avoid any confusion.

Cheers