dome272 / Diffusion-Models-pytorch

PyTorch implementation of Diffusion Models (https://arxiv.org/pdf/2006.11239.pdf)
Apache License 2.0
1.11k stars, 256 forks

Time embedding #27

Open Chirobocea opened 1 year ago

Chirobocea commented 1 year ago

Hi! Great code, thanks for sharing!

I noticed something a bit odd in this code. Is there a reason why you chose to apply SiLU() right after the sinusoidal embedding? It seems unnatural, as it might change the desired properties of the embedding.
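To make the concern concrete, here is a small pure-Python sketch of what SiLU does to raw sinusoidal features. The helper names `silu` and `sinusoidal_embedding`, the timestep `500.0`, and the 8-channel size are illustrative assumptions, not code from the repository; the frequency schedule is the standard Transformer-style one.

```python
import math


def silu(x: float) -> float:
    """SiLU (swish): x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def sinusoidal_embedding(t: float, channels: int) -> list:
    """Sin/cos embedding of a scalar timestep (assumed Transformer-style frequencies)."""
    half = channels // 2
    inv_freq = [1.0 / (10000 ** (2 * i / channels)) for i in range(half)]
    return [math.sin(t * f) for f in inv_freq] + [math.cos(t * f) for f in inv_freq]


emb = sinusoidal_embedding(t=500.0, channels=8)
activated = [silu(v) for v in emb]

# Raw components lie in [-1, 1]. SiLU maps that interval to roughly
# [-0.27, 0.73], so negative components are squashed toward zero and the
# sin/cos symmetry of the encoding is no longer preserved.
print("raw range:      ", min(emb), max(emb))
print("after SiLU range:", min(activated), max(activated))
```

Whether this distortion actually hurts sample quality is an empirical question; the sketch only shows that the activation is not neutral with respect to the encoding.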

Maybe you forgot to add a learnable projection of the embedding, for example by adding this to the U-Net:

    self.time_embed = nn.Sequential(
        nn.Linear(time_dim, time_dim),
        nn.SiLU(),
        nn.Linear(time_dim, time_dim),
    )

And also changing the forward pass by adding:

    def forward(self, x, t):
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)
        t = self.time_embed(t)

Under these conditions, the SiLU() activation in each block's projection makes sense, since it is then simply the activation applied to the learned embedding.
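Putting the suggestion together, a minimal self-contained sketch might look like the following. The class name `TimeConditioning`, the `emb_layer` per-block projection (SiLU followed by a linear layer), and the sizes `256` and `128` are assumptions made for illustration; only `time_embed`, `pos_encoding`, and the three `forward` lines come from the proposal above.

```python
import torch
import torch.nn as nn


class TimeConditioning(nn.Module):
    """Hypothetical stand-alone module: sinusoidal encoding -> learned MLP -> block projection."""

    def __init__(self, time_dim: int = 256, out_channels: int = 128):
        super().__init__()
        self.time_dim = time_dim
        # Learnable projection proposed in this issue.
        self.time_embed = nn.Sequential(
            nn.Linear(time_dim, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )
        # Assumed form of a per-block projection: SiLU, then a linear layer.
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_channels))

    def pos_encoding(self, t: torch.Tensor, channels: int) -> torch.Tensor:
        # t has shape (batch, 1); the result has shape (batch, channels).
        exponents = torch.arange(0, channels, 2, device=t.device).float() / channels
        inv_freq = 1.0 / (10000 ** exponents)
        return torch.cat([torch.sin(t * inv_freq), torch.cos(t * inv_freq)], dim=-1)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)  # fixed sinusoidal features
        t = self.time_embed(t)                   # learned embedding
        return self.emb_layer(t)                 # SiLU now acts on the learned embedding


cond = TimeConditioning()
timesteps = torch.randint(low=1, high=1000, size=(4,))
out = cond(timesteps)
print(out.shape)  # torch.Size([4, 128])
```

With this arrangement the first nonlinearity the sinusoidal features meet sits inside a learned MLP, so the network can choose how to mix the frequencies before any per-block SiLU is applied.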

dome272 commented 1 year ago

Thank you for the catch. You are right, that would not make much sense the way it's done and probably gives worse results.