ML-GSAI / BFN-Solver

Official PyTorch implementation for "Unifying Bayesian Flow Networks and Diffusion Models through Stochastic Differential Equations"

The difference between the code and pseudocode #2

Open explorer1212 opened 1 month ago

explorer1212 commented 1 month ago

Dear authors, excellent work! While reading the implementation, I noticed that the code differs slightly from the pseudocode. Take bfnsolver++1 as an example. Is x_t in the code the variable "μ_i" in the pseudocode? And is noise_pred "\hat{x_i}"? If so, why does the pseudocode contain an equation relating \hat{x_i} to epsilon?

def ode_bfnsolver1_update(self, x_s, step, last_drop=False):
    # x_s -> x_t
    t = torch.ones_like(x_s, device=x_s.device) * (1 - self.times[step])
    # noise predict and x0 predict
    with torch.no_grad():
        noise_pred = self.unet(x_s, t).reshape(x_s.shape)
    alpha_t, sigma_t = self.alpha_t[step], self.sigma_t[step]
    x0_pred = (x_s - sigma_t * noise_pred) / alpha_t

    # clip x0
    x0_pred = x0_pred.clip(min=-1.0, max=1.0)
    noise_pred = (x_s - x0_pred * alpha_t) / sigma_t

    # get schedule
    lambda_t, lambda_s = self.lambda_t[step + 1], self.lambda_t[step]
    alpha_t, alpha_s = self.alpha_t[step + 1], self.alpha_t[step]
    sigma_t, sigma_s = self.sigma_t[step + 1], self.sigma_t[step]
    h = lambda_t - lambda_s

    if last_drop and step == self.num_steps - 1:
        return x0_pred, x0_pred
    else:
        x_t = (alpha_t / alpha_s) * x_s - (sigma_t * (torch.exp(h) - 1.0)) * noise_pred

    return x_t, x0_pred


explorer1212 commented 1 month ago

By the way, since many models predict x directly, is there any way to use the x prediction directly instead of the noise?
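A minimal sketch of how the two parameterizations relate, using only the relation x_s = alpha_s * x0 + sigma_s * noise that the quoted code already relies on. The function names are hypothetical and not from the repository. The x0-form update is the first-order DPM-Solver++ style step, and it assumes lambda = log(alpha / sigma), which this thread does not confirm for BFN-Solver. Scalars are used for brevity; the same arithmetic applies elementwise to tensors.

```python
import math


def x0_from_noise(x_s, noise_pred, alpha_s, sigma_s):
    # Same conversion as `x0_pred = (x_s - sigma_t * noise_pred) / alpha_t` in the quoted code.
    return (x_s - sigma_s * noise_pred) / alpha_s


def noise_from_x0(x_s, x0_pred, alpha_s, sigma_s):
    # Inverse conversion, as done after clipping in the quoted code.
    return (x_s - alpha_s * x0_pred) / sigma_s


def step_noise_form(x_s, noise_pred, alpha_s, alpha_t, sigma_t, h):
    # Update written in terms of the noise prediction (the form in the quoted code).
    return (alpha_t / alpha_s) * x_s - sigma_t * (math.exp(h) - 1.0) * noise_pred


def step_x0_form(x_s, x0_pred, sigma_s, sigma_t, alpha_t, h):
    # Equivalent update written in terms of the x0 prediction.
    # ASSUMPTION: h = lambda_t - lambda_s with lambda = log(alpha / sigma).
    return (sigma_t / sigma_s) * x_s - alpha_t * (math.exp(-h) - 1.0) * x0_pred


# Illustrative schedule values (made up for this check).
alpha_s, sigma_s = 0.6, 0.8
alpha_t, sigma_t = 0.9, 0.3
h = math.log(alpha_t / sigma_t) - math.log(alpha_s / sigma_s)

x_s, x0_pred = 0.5, 0.2  # an x-prediction model would output x0_pred directly
noise_pred = noise_from_x0(x_s, x0_pred, alpha_s, sigma_s)

a = step_noise_form(x_s, noise_pred, alpha_s, alpha_t, sigma_t, h)
b = step_x0_form(x_s, x0_pred, sigma_s, sigma_t, alpha_t, h)
assert abs(a - b) < 1e-9  # both forms give the same next state
```

Under these assumptions, a model that predicts x directly can either be converted to a noise prediction with noise_from_x0 and passed to the existing update, or used in the x0-form step. Note that the quoted code already recomputes noise_pred from the clipped x0_pred, so its update is effectively driven by the x0 prediction.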

kaiwenxue0 commented 1 month ago

Sorry for the late reply and for any confusion caused by the code.

To clarify:

You can verify this by substituting the expressions for "\bar{\alpha}_t" and "\bar{\sigma}_t", together with the "\hat{\epsilon}" terms, into the last update equation inside the for loop of the pseudocode. After this substitution, the code matches the intended mathematical formulation.

I apologize again for the misleading code. We will update it promptly to ensure consistency with the pseudocode.

Thank you for pointing this out!

explorer1212 commented 3 weeks ago

Thanks for your reply! I have another question: the SDE-based sampling methods seem to lack diversity, especially when applied to molecular generation. Is there a good solution or reference? Thank you very much!