Hey!
First, thank you for your contribution to the field and for open-sourcing your code! Really appreciated! I hope it's OK that I reach out to you here: I want to use D3PM for protein sequences (similar to what you did with LMs), but I'm struggling to understand the following point about the reverse process:
In your paper you've mentioned:
In theory, I agree with this. But when it comes to the implementation, it's not possible to calculate the sum in the last line. Whereas the original D3PM paper uses a mean and log-scale parameterization to predict that distribution, as far as I understand your code only considers the logits of $p_\theta(x_0 \mid x_t)$.
More specifically, I looked at:
`MLMDiffusionTransformer.forward()`:
and
`MLMDiffusion.forward()`:
Am I missing something? If not, how does it match the paper?
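For reference, this is how I picture the sum over $\tilde{x}_0$ if it is taken per token rather than over whole sequences. It is a minimal NumPy sketch that assumes the D3PM row convention $q(x_t \mid x_{t-1}) = \mathrm{Cat}(x_t; x_{t-1} Q_t)$, $\bar{Q}_{t-1} = Q_1 \cdots Q_{t-1}$, and a uniform transition matrix chosen purely for illustration. The helpers `uniform_Q` and `reverse_probs` are hypothetical names of mine, not functions from your repo, and the random logits only stand in for a model's output.

```python
import numpy as np


def uniform_Q(beta: float, K: int) -> np.ndarray:
    """Hypothetical helper: uniform one-step matrix, Q[i, j] = q(x_t = j | x_{t-1} = i)."""
    return (1.0 - beta) * np.eye(K) + (beta / K) * np.ones((K, K))


def reverse_probs(x_t: np.ndarray, p_x0: np.ndarray,
                  Q_t: np.ndarray, Qbar_tm1: np.ndarray) -> np.ndarray:
    """Per-token p(x_{t-1} | x_t), proportional to sum_{x0} q(x_{t-1}, x_t | x0) p(x0 | x_t).

    x_t:      (L,) current tokens
    p_x0:     (L, K) predicted p(x0 | x_t), e.g. a softmax over the model's logits
    Q_t:      (K, K) one-step matrix, Q_t[i, j] = q(x_t = j | x_{t-1} = i)
    Qbar_tm1: (K, K) cumulative matrix Q_1 ... Q_{t-1}, Qbar[i, j] = q(x_{t-1} = j | x0 = i)
    """
    # q(x_t | x_{t-1} = k) for every candidate k: column x_t of Q_t, shape (L, K)
    fact1 = Q_t[:, x_t].T
    # sum over x0 of p(x0 | x_t) * q(x_{t-1} = k | x0): one (L, K) @ (K, K) matmul
    fact2 = p_x0 @ Qbar_tm1
    unnorm = fact1 * fact2
    # normalize over the K possible values of x_{t-1} at each position
    return unnorm / unnorm.sum(axis=-1, keepdims=True)


# Toy example: K = 4 tokens, sequence length L = 3, step t = 3
K, L = 4, 3
Qs = [uniform_Q(b, K) for b in (0.1, 0.2, 0.3)]  # Q_1, Q_2, Q_3
Qbar_tm1 = Qs[0] @ Qs[1]                          # Qbar_{t-1} = Q_1 Q_2
rng = np.random.default_rng(0)
logits = rng.normal(size=(L, K))                  # stand-in for model logits
p_x0 = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
x_t = np.array([0, 2, 3])
probs = reverse_probs(x_t, p_x0, Qs[2], Qbar_tm1)
print(probs.shape)  # (3, 4)
```

With a one-hot `p_x0` this reduces to the closed-form posterior $q(x_{t-1} \mid x_t, x_0)$, which is a quick sanity check on the indexing.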
Thanks a lot!
Sagi