AndyCao1125 / SDDPM

[WACV 2024] Spiking Denoising Diffusion Probabilistic Models

Conditional generation #3

Open 1439278026 opened 3 months ago

1439278026 commented 3 months ago

Thank you for sharing your code; it is excellent work. I noticed that sample generation in your paper and code is unconditional. What would I need to change to achieve conditional generation? For example, if I run experiments on the MNIST dataset, how can I make the trained model generate samples of a specific digit that I choose?

AndyCao1125 commented 3 months ago

Thanks for your question. Our spiking UNet can also perform conditional generation; only minor changes to the model are required. Here is an example of a simple conditional spiking UNet:

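import torch.nn as nn

class Spk_UNet_cond(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout, timestep, img_ch=3, num_classes=None):
        super().__init__()
        self.timestep = timestep  # number of spiking simulation steps
        time_dim = 4 * ch  # width of the timestep embedding
        if num_classes is not None:
            # One learned vector per class, same width as the timestep embedding
            self.label_embedding = nn.Embedding(num_classes, time_dim)
        ...

    def forward(self, x, t, y=None):
        # Repeat the input over the spiking time steps: [timestep, B, C, H, W]
        x = x.unsqueeze(0).repeat(self.timestep, 1, 1, 1, 1)

        # Timestep embedding
        temb = self.time_embedding(t)
        if y is not None:
            # Add the class-label embedding to the timestep embedding
            temb = temb + self.label_embedding(y)
        ...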

The conditional model is almost identical to the unconditional one: it passes the class label through an nn.Embedding layer and adds the resulting vector to the timestep embedding. More examples and details can be found in many GitHub repositories and tutorials (e.g., How To Train a Conditional Diffusion Model From Scratch, W&B). Adding conditions to the spiking-based diffusion process would be an interesting topic for future work.
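A self-contained sketch of the idea described above, showing that the label embedding has the same shape as the timestep embedding so the two can simply be summed. The names CondTimeEmbedding and sinusoidal_embedding are hypothetical, and the sinusoidal-plus-MLP time embedding follows the usual DDPM recipe; it is an assumption about what time_embedding does, not the repository's actual module.

```python
import math
from typing import Optional

import torch
import torch.nn as nn


def sinusoidal_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """DDPM-style sinusoidal embedding of integer timesteps, shape [B, dim] (dim even)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)


class CondTimeEmbedding(nn.Module):
    """Hypothetical helper: timestep embedding plus an optional class-label embedding."""

    def __init__(self, ch: int, num_classes: Optional[int] = None):
        super().__init__()
        self.ch = ch
        time_dim = 4 * ch  # same width as time_dim in the snippet above
        self.time_mlp = nn.Sequential(
            nn.Linear(ch, time_dim), nn.SiLU(), nn.Linear(time_dim, time_dim)
        )
        # One learned vector per class, matching the width of the time embedding
        self.label_embedding = nn.Embedding(num_classes, time_dim) if num_classes else None

    def forward(self, t: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        temb = self.time_mlp(sinusoidal_embedding(t, self.ch))
        if y is not None and self.label_embedding is not None:
            temb = temb + self.label_embedding(y)  # element-wise sum, shape unchanged
        return temb


emb = CondTimeEmbedding(ch=32, num_classes=10)
t = torch.randint(0, 1000, (8,))
y = torch.full((8,), 3, dtype=torch.long)  # e.g. request the digit "3" on MNIST
print(emb(t).shape, emb(t, y).shape)  # torch.Size([8, 128]) torch.Size([8, 128])
```

Because the sum keeps the width at 4 * ch, none of the downstream residual blocks need to change.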

1439278026 commented 3 months ago

Thank you for your reply. I am currently trying a few things.

1439278026 commented 2 months ago

May I ask which methods of encoding the category information you have tried, and how you integrated it into the network? I tried both adding it to and concatenating it with the time embedding, but neither attempt worked.
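A minimal sketch of the two fusion options mentioned above, assuming a time embedding of shape [B, time_dim]. With concatenation the fused vector is twice as wide, so it has to be projected back to time_dim (or every block that consumes it must be resized); the hypothetical FuseLabel module does that projection with an nn.Linear. In both cases nn.Embedding expects integer class ids, not one-hot floats.

```python
import torch
import torch.nn as nn


class FuseLabel(nn.Module):
    """Hypothetical module: fuse a class embedding into the time embedding by 'add' or 'concat'."""

    def __init__(self, num_classes: int, time_dim: int, mode: str = "add"):
        super().__init__()
        if mode not in ("add", "concat"):
            raise ValueError(f"unknown mode: {mode}")
        self.mode = mode
        self.label_embedding = nn.Embedding(num_classes, time_dim)
        # Concatenation doubles the width, so project back to time_dim to keep
        # the input size expected by the downstream blocks.
        self.proj = nn.Linear(2 * time_dim, time_dim) if mode == "concat" else None

    def forward(self, temb: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        lemb = self.label_embedding(y)  # y: integer class ids (torch.long), shape [B]
        if self.mode == "add":
            return temb + lemb
        return self.proj(torch.cat([temb, lemb], dim=1))


temb = torch.randn(4, 128)
y = torch.tensor([0, 1, 2, 3])
print(FuseLabel(10, 128, "add")(temb, y).shape)     # torch.Size([4, 128])
print(FuseLabel(10, 128, "concat")(temb, y).shape)  # torch.Size([4, 128])
```

Addition needs no extra parameters beyond the embedding table; concatenation adds a projection layer but lets the network learn how to mix the two signals.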

AndyCao1125 commented 2 months ago

A common method for conditional image generation, as shown in the example code above and in the tutorial mentioned earlier, is to obtain the category embedding from a single nn.Embedding layer and add it to the time embedding. We did not present conditional generation results in our paper, but this topic is certainly worth further exploration :)
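To generate a chosen MNIST digit, the label has to be fed to the model both during training and during sampling. The sketch below shows a standard DDPM noise-prediction loss and ancestral sampler with that label passed through. ToyCondDenoiser, training_loss and sample_digit are hypothetical names; the toy model is a plain MLP standing in for a conditional spiking UNet with a forward(x, t, y) signature like Spk_UNet_cond above, and the random batch is a placeholder for real MNIST data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T_STEPS = 100  # number of diffusion steps (small for illustration)
betas = torch.linspace(1e-4, 0.02, T_STEPS)  # linear DDPM noise schedule
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)


class ToyCondDenoiser(nn.Module):
    """Hypothetical stand-in for a conditional UNet: predicts noise from (x_t, t, y)."""

    def __init__(self, img_dim: int, num_classes: int, hidden: int = 128):
        super().__init__()
        self.time_embedding = nn.Embedding(T_STEPS, hidden)
        self.label_embedding = nn.Embedding(num_classes, hidden)
        self.inp = nn.Linear(img_dim, hidden)
        self.out = nn.Sequential(nn.SiLU(), nn.Linear(hidden, img_dim))

    def forward(self, x, t, y):
        temb = self.time_embedding(t) + self.label_embedding(y)  # add class info to time info
        h = self.inp(x.flatten(1)) + temb
        return self.out(h).view_as(x)


def training_loss(model, x0, y):
    """DDPM noise-prediction loss, with the class label passed to the model."""
    t = torch.randint(0, T_STEPS, (x0.shape[0],))
    noise = torch.randn_like(x0)
    ab = alpha_bars[t].view(-1, 1, 1, 1)
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * noise
    return F.mse_loss(model(x_t, t, y), noise)


@torch.no_grad()
def sample_digit(model, digit, n, shape=(1, 28, 28)):
    """Ancestral DDPM sampling with every sample conditioned on the same class id."""
    y = torch.full((n,), digit, dtype=torch.long)
    x = torch.randn(n, *shape)
    for i in reversed(range(T_STEPS)):
        t = torch.full((n,), i, dtype=torch.long)
        eps = model(x, t, y)
        mean = (x - betas[i] / (1 - alpha_bars[i]).sqrt() * eps) / alphas[i].sqrt()
        x = (mean + betas[i].sqrt() * torch.randn_like(x)) if i > 0 else mean
    return x


model = ToyCondDenoiser(img_dim=28 * 28, num_classes=10)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
x0, labels = torch.randn(16, 1, 28, 28), torch.randint(0, 10, (16,))  # placeholder batch
loss = training_loss(model, x0, labels)
opt.zero_grad()
loss.backward()
opt.step()
samples = sample_digit(model, digit=7, n=4)
print(samples.shape)  # torch.Size([4, 1, 28, 28])
```

The only change relative to unconditional DDPM is the extra y argument; the schedule, loss and sampler are unchanged.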