HannesStark / dirichlet-flow-matching

MIT License

A couple of questions #1

Closed erl-j closed 4 months ago

erl-j commented 4 months ago

Hello!

First of all thank you for sharing this exciting work.

I have a couple of questions:

First, the following function is called during both training and inference:

def expand_simplex(xt, alphas, prior_pseudocount):
    prior_weights = (prior_pseudocount / (alphas + prior_pseudocount - 1))[:, None, None]
    return torch.cat([xt * (1 - prior_weights), xt * prior_weights], -1), prior_weights

However, I am unable to relate this to the algorithm described in the paper. What is the purpose of this operation?

Secondly, there are some hyperparameters that I'm not sure how to set and which are not discussed in the paper AFAIK. Would you mind sharing some intuitions on how to set these as a function of vocabulary size (K) and sequence length (T)?

    parser.add_argument("--simplex_spacing", type=int, default=1000)
    parser.add_argument("--prior_pseudocount", type=float, default=2)
    parser.add_argument("--alpha_scale", type=float, default=2)
    parser.add_argument("--alpha_max", type=float, default=8)

I understand that this could be a difficult question, but any hints would be of great help :)

HannesStark commented 4 months ago

Hi!

Thanks for asking. The expand_simplex function is an additional way of encoding the current flow matching time step (which in our case is represented by alpha). The function turns a K-dimensional point on the simplex into a 2*K-dimensional representation that encodes the current time along the probability path. We tried this in earlier experiments; it likely does not impact performance significantly, and you can remove it (we already provide the flow time to the model explicitly).
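As a concrete illustration (a minimal standalone sketch of the function quoted above, not additional repo code): the prior weight equals 1 at alpha = 1 (all mass in the second K-dimensional half, pure prior) and decays toward 0 as alpha grows, so the split itself encodes the flow time:

```python
import torch

def expand_simplex(xt, alphas, prior_pseudocount):
    # Weight is 1 at alpha = 1 and decays toward 0 as alpha grows.
    prior_weights = (prior_pseudocount / (alphas + prior_pseudocount - 1))[:, None, None]
    return torch.cat([xt * (1 - prior_weights), xt * prior_weights], -1), prior_weights

xt = torch.full((1, 1, 4), 0.25)  # batch of one uniform simplex point, K = 4
for a in [1.0, 2.0, 8.0]:
    _, w = expand_simplex(xt, torch.tensor([a]), prior_pseudocount=2.0)
    print(a, round(w.item(), 4))  # 1.0 -> 1.0, 2.0 -> 0.6667, 8.0 -> 0.2222
```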

The parameters: they should be independent of sequence length. We have not compared them across different vocabulary sizes, but we found Dirichlet FM to work fine for K=20 as well. However, we think it is likely that simplex diffusion or flow matching approaches, including Dirichlet FM, will struggle with large vocabulary sizes such as 50,000.

simplex_spacing: This is deprecated and no longer used. It used to control the alpha_spacing parameter of the DirichletConditionalFlow class; I made a commit adding this to the parameter description. To compute the magnitude of the conditional vector field (the C of equation 16), we need to evaluate the derivative of the regularized incomplete beta function. To do so, we discretize the regularized incomplete beta function and evaluate it at different alphas (see https://github.com/HannesStark/dirichlet-flow-matching/blob/0c7b80d1e96472571a72dd45747ddbb2cb95c8b8/utils/flow_utils.py#L128C49-L128C56). simplex_spacing is the number of alphas this is discretized into; a higher number leads to a finer discretization.
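For intuition, here is a standalone sketch of that idea using SciPy rather than the repo's code (the values of x and beta are made up for illustration): discretize alpha on a grid and take a finite-difference derivative of the regularized incomplete beta function along it.

```python
import numpy as np
from scipy.special import betainc

spacing = 1000                    # plays the role of simplex_spacing: grid resolution
alphas = np.linspace(1.0, 8.0, spacing)
x, beta = 0.3, 3.0                # illustrative values, not from the paper

# Regularized incomplete beta I_x(alpha, beta), evaluated on the alpha grid.
vals = betainc(alphas, beta, x)
# Finite-difference estimate of d/dalpha I_x(alpha, beta); a finer grid
# (larger spacing) gives a more accurate estimate.
deriv = np.gradient(vals, alphas)
```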

prior_pseudocount: Hyperparameter for the expand_simplex function. It can be kept at 2 and should not matter much.

alpha_scale: Controls the alphas at which training is performed; if it is higher, higher alphas are sampled during training. Recall that alphas correspond to diffusion time in Dirichlet flow matching: alpha=1 corresponds to full noise, and alpha_max (e.g. alpha_max = 8) corresponds to clean data.

alpha_max: Controls the maximum value up to which we run the inference process. In equation 14, the probability path goes from alpha=1 to alpha=infinity; in practice we cut off alpha at alpha_max.
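A hypothetical sketch of what that cutoff looks like in an Euler-style integration loop. Note this is not the repo's sampler: `vector_field` is a stand-in for the model-derived field on the simplex, and the step count and re-projection are illustrative choices.

```python
import torch

def integrate(xt, vector_field, alpha_max=8.0, num_steps=100):
    # Integrate from alpha = 1 up to the alpha_max cutoff with Euler steps.
    alphas = torch.linspace(1.0, alpha_max, num_steps + 1)
    for a_cur, a_next in zip(alphas[:-1], alphas[1:]):
        xt = xt + (a_next - a_cur) * vector_field(xt, a_cur.item())
        xt = xt.clamp(min=0)
        xt = xt / xt.sum(-1, keepdim=True)  # re-project onto the simplex
    return xt
```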

erl-j commented 4 months ago

Thank you so much, this was really helpful.

I'm curious about why you are not optimistic about simplex methods scaling to large vocabularies.

SSD-LM uses a 50K vocabulary size, although they start from a pre-trained RoBERTa checkpoint.

HannesStark commented 4 months ago

Nice to hear that it helped!

We came across this paper. They do not actually run a diffusion process on the simplex; they fully relax to Euclidean space and do continuous diffusion there.

erl-j commented 4 months ago

Regardless, I'm curious why you don't believe simplex diffusion will work for large vocabularies :)