cvignac / DiGress

code for the paper "DiGress: Discrete Denoising diffusion for graph generation"
MIT License

Confused about the sampling process #54

Closed waitma closed 1 year ago

waitma commented 1 year ago

Hello, your work is excellent. After reading the code, I was confused about the sampling process. My questions are:

  1. What does the function p_s_and_t_given_0_X compute?
  2. What does the variable weighted_X represent in this line of code? weighted_X = pred_X.unsqueeze(-1) * p_s_and_t_given_0_X

Thanks so much for your kind reply!

cvignac commented 1 year ago

Sampling is done by marginalizing over the network predictions, which corresponds to this formula:

p(x_i^{t-1} \mid G^t) = \int_{x_i} p(x_i^{t-1} \mid x_i, G^t) \, dp(x_i \mid G^t) = \sum_{x \in \mathcal{X}} p(x_i^{t-1} \mid x_i = x, G^t) \, \hat p^X_i(x)

p_s_and_t_given_0_X corresponds to the first term, and weighted_X corresponds to the product of the two terms, before we compute the sum.
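To make the marginalization concrete, here is a minimal sketch of that weighted sum (illustrative only, using NumPy with random tensors rather than the repo's actual PyTorch code; the variable names and shapes follow the docstring quoted below):

```python
import numpy as np

rng = np.random.default_rng(0)
bs, n, d = 2, 5, 4  # batch size, number of nodes, number of node classes

def normalize(a):
    # Turn nonnegative values into a probability distribution over the last axis.
    return a / a.sum(axis=-1, keepdims=True)

# pred_X: the network's estimate of p(x0 | G^t), one distribution per node.
pred_X = normalize(rng.random((bs, n, d)))

# p_s_and_t_given_0_X[b, i, x0, xs]: posterior q(x_s | x_t, x0) for every
# candidate clean state x0 (placeholder values here).
p_s_and_t_given_0_X = normalize(rng.random((bs, n, d, d)))

# Weight each posterior by the predicted probability of its x0 ...
weighted_X = pred_X[..., None] * p_s_and_t_given_0_X  # bs, n, d (x0), d (x_s)

# ... then sum over x0 to marginalize, giving the sampling distribution.
prob_X = normalize(weighted_X.sum(axis=2))            # bs, n, d
```

Each row of prob_X is the categorical distribution p(x_i^{t-1} | G^t) from which the next state is sampled.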

waitma commented 1 year ago

Thanks! Though your explanation is very clear, I still don't seem to fully understand. The function compute_batched_over0_posterior_distribution is used to compute the first term you mentioned earlier, and I found that its docstring is:

    """ M: X or E
    Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0
    X_t: bs, n, dt or bs, n, n, dt
    Qt: bs, d_t-1, dt
    Qsb: bs, d0, d_t-1
    Qtb: bs, d0, dt
    """

But in the actual computation, x0 never appears. Is my understanding of the computation incorrect?

cvignac commented 1 year ago

I struggle a bit to explain this, but you need to remember that x0 is simply a one-hot column vector (a single 1, all other entries 0). If A is a matrix, A @ x0 is a specific column of A.

If we want to compute the value of A @ x0 for all possible x0, we simply need to return the matrix A. That's why x0 never appears in practice.
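A tiny NumPy illustration of this point (not code from the repo):

```python
import numpy as np

A = np.arange(12, dtype=float).reshape(3, 4)

# x0 as a one-hot column vector: A @ x0 picks out a single column of A.
x0 = np.zeros((4, 1))
x0[2] = 1.0
assert np.array_equal(A @ x0, A[:, [2]])

# Stacking the one-hot vectors for every possible x0 gives the identity
# matrix, so "A @ x0 for all x0" is just A @ I = A.
all_x0 = np.eye(4)
assert np.array_equal(A @ all_x0, A)
```

This is why the function can return the full matrix instead of ever materializing x0.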


waitma commented 1 year ago

OK, thank you very much!

chinmay5 commented 1 year ago

A small follow-up on the same. I am struggling with the denominator term in the equation. Can you please explain how the xt.T term appears in: Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0

Thanks

Mutual-Luo commented 4 months ago

> A small follow-up on the same. I am struggling with the denominator term in the equation. Can you please explain how the xt.T term appears? Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0
>
> Thanks

Actually, the code doesn't compute it in that order. It transposes X_t first, applies Qtb, then transposes back, using the identity (A @ B)^T = B^T @ A^T:

X_t_transposed = X_t.transpose(-1, -2)      # bs, dt, N
prod = Qtb @ X_t_transposed                 # bs, d0, N
prod = prod.transpose(-1, -2)               # bs, N, d0
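A quick check of that identity (a NumPy sketch, dropping the batch dimension for clarity; shapes follow the docstring):

```python
import numpy as np

rng = np.random.default_rng(1)
N, d0, dt = 5, 3, 4
X_t = rng.random((N, dt))    # one row per node, dt classes
Qtb = rng.random((d0, dt))   # cumulative transition matrix

# The docstring's form, computed for all x0 at once: Qtb @ xt.T,
# then transposed back so rows index nodes.
prod = (Qtb @ X_t.T).T       # N, d0

# By (A @ B)^T = B^T @ A^T this equals X_t @ Qtb.T directly.
assert np.allclose(prod, X_t @ Qtb.T)
```

So the two transposes are just a reshuffling; the denominator term x0 @ Qtb @ xt.T is still what gets computed, one entry per candidate x0.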
llllly26 commented 1 month ago

> I struggle a bit to explain this, but you need to remember that x0 is simply a one-hot column vector. If A is a matrix, A @ x0 is a specific column of A. If we want to compute the value of A @ x0 for all possible x0, we simply need to return the matrix A. That's why x0 never appears in practice.

Hi, thanks for your reply, but why does "If we want to compute the value of A @ x0 for all possible x0, we simply need to return the matrix A" hold? You said that "If A is a matrix, A @ x0 is a specific column of A", so A @ x0 extracts a single column of A. Why can we then return the original A? Could you reply in detail? Thanks.