HannesStark / dirichlet-flow-matching

MIT License

Conditional probability path's t value #4

Closed: MakssMeiers closed this issue 5 months ago

MakssMeiers commented 5 months ago

Description

We are working on a generative model and want to evaluate your approach, Dirichlet FM. However, we have found some inconsistencies between your research paper and the GitHub project regarding the probability path.

Research paper

On page 4, in Section 3.3, you define the probability path as $p_t(x \mid x_1 = e_i) = \mathrm{Dir}(x; \alpha = 1 + t \cdot e_i)$ with $t \in [0, \infty)$, where $e_i$, as stated in Section 3.1 on page 3, is the $i$-th one-hot vector. Later, however, the pseudocode in Appendix B samples $t \sim U[0, 1]$ while keeping the same probability path $p_t(x \mid x_1 = e_i) = \mathrm{Dir}(x; \alpha = 1 + t \cdot e_i)$. Section 2.1 on page 2 also defines the conditional probability path $p_t(x \mid x_1)$ as a time-evolving distribution with $t \in [0, 1]$.

GitHub project

However, in the GitHub file dirichlet-flow-matching/utils/flow-utils, the function sample_cond_prob_path computes the alpha parameter as follows (lines 92-96):

```python
alphas = torch.from_numpy(1 + scipy.stats.expon().rvs(size=B) * args.alpha_scale).to(seq.device).float()
if args.fix_alpha:
    alphas = torch.ones(B, device=seq.device) * args.fix_alpha
alphas_ = torch.ones(B, L, alphabet_size, device=seq.device)
alphas_ = alphas_ + seq_one_hot * (alphas[:, None, None] - 1)
```

This means that, if fix_alpha is False, the concentration parameters are $\alpha = 1 + e_i \cdot (1 + \lambda s - 1) = 1 + \lambda s \cdot e_i$, where $s$ is a sample from the exponential distribution and $\lambda$ is `args.alpha_scale`. In other words, $t = \lambda s$. But samples from the exponential distribution do not lie in the interval $[0, 1]$, and the paper never states that $t$ is exponentially distributed.
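To make the point about the quoted alpha computation concrete, here is a minimal NumPy re-creation of it. The helper name `sample_alphas` and its signature are hypothetical (the repository uses torch and an `args` object); only the arithmetic mirrors the quoted lines. It returns the per-sequence `alphas` and the per-position `alphas_`, so you can check that off-target entries stay at 1 and the target entry equals $1 + t$ with $t = \lambda s$, $s \sim \mathrm{Exp}(1)$.

```python
import numpy as np

def sample_alphas(seq, alphabet_size, alpha_scale=1.0, fix_alpha=None, rng=None):
    """Hypothetical NumPy mirror of the alpha computation in sample_cond_prob_path."""
    rng = np.random.default_rng() if rng is None else rng
    B, L = seq.shape
    # One-hot targets e_i for every position: shape (B, L, K)
    seq_one_hot = np.eye(alphabet_size)[seq]
    # Per-sequence value 1 + alpha_scale * s with s ~ Exp(1),
    # matching scipy.stats.expon().rvs(size=B), whose default scale is 1
    alphas = 1.0 + rng.exponential(1.0, size=B) * alpha_scale
    if fix_alpha:
        alphas = np.ones(B) * fix_alpha
    alphas_ = np.ones((B, L, alphabet_size))
    # Only the target category receives extra mass: 1 + (alphas - 1) * e_i = 1 + t * e_i
    alphas_ = alphas_ + seq_one_hot * (alphas[:, None, None] - 1)
    return alphas, alphas_

rng = np.random.default_rng(0)
seq = np.array([[0, 2, 1], [3, 3, 0]])
alphas, alphas_ = sample_alphas(seq, alphabet_size=4, rng=rng)
t = alphas - 1.0  # the implied "time" of the probability path, unbounded above
```

With `alpha_scale = 1`, the implied time exceeds 1 with probability $P(s > 1) = e^{-1} \approx 0.37$, so a sizeable share of training samples fall outside $[0, 1]$.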

HannesStark commented 5 months ago

Hi, thanks for spotting this; we noticed it too during our camera-ready preparations. The training times in the training algorithm in the appendix should indeed be sampled as $t \sim \mathrm{Exp}(1)$ (likely $t \sim U[0, t_{max}]$ would work just as well). We corrected this in the camera-ready version, which we will also upload to arXiv soon:

[image: corrected training algorithm from the camera-ready version]

Please let me know if there are any remaining questions.
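The corrected training step described in the reply can be sketched as follows: draw $t \sim \mathrm{Exp}(1)$ (or, as the reply suggests might work equally well, $t \sim U[0, t_{max}]$), then draw $x_t \sim \mathrm{Dir}(1 + t \cdot e_i)$ at every position. This is a sketch under assumptions, not the repository's code: the function names and the default `t_max = 8.0` are hypothetical, and the Dirichlet draw uses the standard trick of normalizing independent Gamma samples so that each position can have its own concentration vector.

```python
import numpy as np

def sample_t(rng, size, mode="exp", t_max=8.0):
    """Training times: Exp(1) as in the corrected algorithm, or U[0, t_max] (t_max is assumed)."""
    if mode == "exp":
        return rng.exponential(1.0, size=size)
    if mode == "uniform":
        return rng.uniform(0.0, t_max, size=size)
    raise ValueError(f"unknown mode: {mode}")

def sample_xt(rng, seq, K, t):
    """Draw x_t ~ Dir(1 + t * e_i) independently for every position of every sequence."""
    e = np.eye(K)[seq]                     # one-hot targets, shape (B, L, K)
    alpha = 1.0 + t[:, None, None] * e     # concentration 1 + t * e_i, shape (B, L, K)
    g = rng.standard_gamma(alpha)          # independent Gamma(alpha_k, 1) draws
    return g / g.sum(axis=-1, keepdims=True)  # normalizing Gammas gives a Dirichlet sample

rng = np.random.default_rng(1)
seq = rng.integers(0, 4, size=(3, 7))
t = sample_t(rng, 3)
x_t = sample_xt(rng, seq, K=4, t=t)
```

As $t$ grows, the target coordinate has mean $(t + 1)/(t + K)$, so $x_t$ concentrates on the vertex $e_i$; this is why the time has to range over $[0, \infty)$ (or a large $t_{max}$) rather than $[0, 1]$.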

MakssMeiers commented 5 months ago

Thank you for the response, Hannes! I have one more question. Based on both Equations 15 and 16 and the GitHub code for the vector field, shouldn't the inference algorithm in the pseudocode have a minus sign before $I'_x(a, b)$ (the derivative of the regularized incomplete beta function)?

HannesStark commented 5 months ago

Good catch, thanks!
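A quick way to see why the minus sign is needed, based on our own derivation rather than on the paper or the repository: under $\mathrm{Dir}(1 + t \cdot e_i)$ with $K$ categories, the coordinate $x_i$ is $\mathrm{Beta}(t+1, K-1)$. Requiring its CDF $I_{x_i}(t+1, K-1)$ to be transported by the $i$-th coordinate of $C(x_i, t)\,(e_i - x)$ gives $C(x_i, t) = -\partial_a I_{x_i}(a, K-1)\big|_{a=t+1} \cdot \frac{B(t+1, K-1)}{(1-x_i)^{K-1} x_i^{t}}$. Since $I_x(a, b)$ decreases in $a$, the derivative is negative, and the leading minus makes $C$ positive, so the flow points toward $e_i$. The helper names below are hypothetical; $I_x(a, b)$ is computed with a plain Simpson rule to stay dependency-free, and the derivative with a central finite difference.

```python
import math

def reg_inc_beta(x, a, b, n=2000):
    """I_x(a, b) via composite Simpson's rule; assumes a >= 1, b >= 1, 0 <= x <= 1, even n."""
    def f(u):
        return u ** (a - 1) * (1 - u) ** (b - 1)
    h = x / n
    s = f(0.0) + f(x)
    for k in range(1, n):
        s += (4 if k % 2 else 2) * f(k * h)
    log_beta = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return (s * h / 3) / math.exp(log_beta)

def d_inc_beta_da(x, a, b, eps=1e-4):
    """Central finite difference of I_x(a, b) with respect to a (needs a - eps >= 1)."""
    return (reg_inc_beta(x, a + eps, b) - reg_inc_beta(x, a - eps, b)) / (2 * eps)

def cond_vf_scale(x_i, t, K):
    """Scalar C(x_i, t) multiplying (e_i - x); assumes t > eps and 0 < x_i < 1."""
    a, b = t + 1.0, K - 1.0
    beta_ab = math.exp(math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b))
    # Leading minus: dI/da < 0, so C > 0 and the flow moves toward e_i
    return -d_inc_beta_da(x_i, a, b) * beta_ab / ((1 - x_i) ** (K - 1) * x_i ** t)

deriv = d_inc_beta_da(0.4, 3.0, 3.0)
c = cond_vf_scale(0.4, 2.0, 4)
```

For $K = 2$ everything is available in closed form ($I_x(a, 1) = x^a$), which gives $C = -x \ln x / \big((t+1)(1-x)\big)$ and serves as a sanity check for the numerical version.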