ngruver / NOS

Protein Design with Guided Discrete Diffusion
https://arxiv.org/abs/2305.20009
MIT License

Understanding the discrete reverse process #2

Open SagiPolaczek opened 1 year ago

SagiPolaczek commented 1 year ago

Hey!

First, thank you for your contribution to the field, as well as for open-sourcing your code! It's really appreciated! I hope it's OK to approach you here: I want to use D3PM for protein sequences (similar to what you did with the LM), but I'm struggling to understand the following point in the reverse process:

In your paper you mention:

[Screenshot of the reverse-process equation from the paper, ending in $p_\theta(x_{t-1} | x_t) = \sum_{x_0} q(x_{t-1} | x_t, x_0) \, p_\theta(x_0 | x_t)$]

Which, theoretically, I agree with. But when it comes to implementation, I don't see how to calculate the sum in the last line. The original D3PM paper uses a mean & log-scale parameterization to predict that distribution, while, as far as I understand, your code only considers the logits of $p_\theta(x_0 | x_t)$.
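
To make my question concrete, here is a sketch of how I imagined the sum over $x_0$ could be computed exactly, since $x_0$ only takes $K$ discrete values. This is my own code, not from your repo; `Q_t`, `Qbar_tm1`, and `Qbar_t` are placeholder names for the one-step and cumulative transition matrices:

    import torch
    import torch.nn.functional as F

    def reverse_step_probs(logits, x_t, Q_t, Qbar_tm1, Qbar_t, eps=1e-20):
        # logits:   (L, K) logits of p_theta(x0 | x_t) per position
        # x_t:      (L,)   current corrupted tokens
        # Q_t:      (K, K) one-step matrix, Q_t[i, j] = q(x_t = j | x_{t-1} = i)
        # Qbar_tm1: (K, K) cumulative product Q_1 ... Q_{t-1}
        # Qbar_t:   (K, K) cumulative product Q_1 ... Q_t
        p0 = F.softmax(logits, dim=-1)                 # p_theta(x0 | x_t), (L, K)
        # Bayes: q(x_{t-1} | x_t, x0) = q(x_t | x_{t-1}) q(x_{t-1} | x0) / q(x_t | x0)
        fact1 = Q_t[:, x_t].T                          # (L, K_prev): q(x_t | x_{t-1})
        denom = Qbar_t[:, x_t].T[:, :, None]           # (L, K0, 1):  q(x_t | x0)
        post = fact1[:, None, :] * Qbar_tm1[None, :, :] / (denom + eps)  # (L, K0, K_prev)
        # weight each conditional posterior by p_theta(x0 | x_t), sum over x0
        probs = (p0[:, :, None] * post).sum(dim=1)     # (L, K_prev)
        return probs / probs.sum(-1, keepdim=True)

Since the protein alphabet is small, this exact $O(K^2)$ sum per position seems tractable (in practice one would do it in log space for stability), so I expected to find something like it in the sampling code.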

More specifically, I looked at MLMDiffusionTransformer.forward():

        sequence_output = self.encoder(embed, encoder_attention_mask=attn_mask)[0]
        prediction_scores = self.cls(sequence_output)

        out = {
            "logits": prediction_scores,
            "sequence_output": sequence_output,
            "embeds": token_embed,
        }

and at MLMDiffusion.forward():

        corrupt_ids, corrupt_mask = (
            self.noise_schedule.corrupt(input_ids, t, corrupt_mask)
        )

        model_output = self.network(
            corrupt_ids,
            t, 
            attn_mask,
        )
        logits = model_output['logits']
        hiddens = model_output['sequence_output']

        loss_fct = nn.CrossEntropyLoss(reduction='none')  # -100 index = padding token
        nll = loss_fct(logits.view(-1, logits.shape[-1]), input_ids.view(-1))

Am I missing something? If not, how does this match the paper? As far as I can tell, the loss above is just the reconstruction cross-entropy $-\log p_\theta(x_0 | x_t)$, and I don't see where $q(x_{t-1} | x_t, x_0)$ enters.

Thanks a lot! Sagi