yang-song / score_sde_pytorch

PyTorch implementation for Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021, Oral)
https://arxiv.org/abs/2011.13456
Apache License 2.0

Round operation for discrete models #2

SANCHES-Pedro closed this issue 3 years ago

SANCHES-Pedro commented 3 years ago

Hello,

Firstly, congratulations on the amazing work. The ICLR award was well deserved!

I don't want to be pedantic, but I noticed that get_score_fn for discrete models has no torch.round() operation, even though t is an integer at training time. As a result, sampling queries the model at slightly different values than it saw during training (e.g. 500.1 instead of 500). I'm not sure whether this really affects performance; it's just an observation.

I would add labels = torch.round(labels) after line 155 of the models/utils.py file.

Many thanks, Pedro
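
To make the observation above concrete, here is a minimal sketch of the label computation with and without rounding. It is written in plain Python so that it stays self-contained. The names timestep_labels, num_scales and round_labels are hypothetical and are not taken from the repository, where the same arithmetic would run on tensors and the rounding would be the torch.round call proposed above.

```python
def timestep_labels(t, num_scales=1000, round_labels=False):
    """Map a continuous time t in [0, 1] to a timestep label.

    Hypothetical helper for illustration only; it assumes the discrete
    model was trained on integer labels 0 .. num_scales - 1.
    """
    # Scale continuous time to the index range used at training time.
    labels = t * (num_scales - 1)
    if round_labels:
        # Snap to the nearest integer label, as suggested in the issue.
        labels = float(round(labels))
    return labels


# A sampler time that does not land exactly on an integer timestep.
t = 500.1 / 999
print(timestep_labels(t))                     # non-integer label near 500.1
print(timestep_labels(t, round_labels=True))  # integer label
```

Without rounding, the model is queried at a label it never saw exactly during training. With rounding, every query matches one of the training labels.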

yang-song commented 3 years ago

Thanks for the comment! You are right that it would be better to add a torch.round operation for discrete models, though I don't think the results would change much. The current code has an additional benefit: you can use the continuous SDE framework even with models pre-trained with discrete losses (such as the DDPM and NCSN models provided by previous work), which allows you to compute log-likelihoods, for example.
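
One way to see why unrounded labels can still work is sketched below. It assumes the pre-trained model conditions on time through a sinusoidal embedding, as DDPM-style networks commonly do. That is an assumption for illustration, not a statement about every model in the repository, and sinusoidal_embedding with its parameters is a hypothetical helper. Such an embedding is defined for any real-valued label and varies smoothly, so a continuous-time solver can query times between the integer training steps.

```python
import math


def sinusoidal_embedding(t, dim=8, max_period=10000.0):
    """Sinusoidal embedding of a (possibly non-integer) timestep label.

    Hypothetical helper for illustration; dim is assumed even and >= 4.
    """
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / (half - 1)) for i in range(half)]
    # Concatenate the sine and cosine features.
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


at_integer = sinusoidal_embedding(500.0)
between_steps = sinusoidal_embedding(500.1)

# Largest change in any embedding component caused by the 0.1 offset.
max_diff = max(abs(a - b) for a, b in zip(at_integer, between_steps))
print(max_diff)
```

Since every frequency is at most 1, a label offset of 0.1 changes each component by at most 0.1. The network therefore sees an input close to the one it was trained on, which is consistent with the expectation above that results would not change much.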

SANCHES-Pedro commented 3 years ago

I see, I hadn't understood the motivation for not adding it. Thanks for clarifying!