andrew-cr / discrete_flow_models

Code for the paper https://arxiv.org/abs/2402.04997
MIT License

Parameterization of the reverse process #2

Open Abc11c opened 4 months ago

Abc11c commented 4 months ago

Hi @andrew-cr,

Thanks for the code release.

I was wondering whether you adopt the x_0 parameterization used in the D3PM paper? If so, could you point me to the line in the code?

Thanks!

andrew-cr commented 4 months ago

Hi! Thanks for the interest in the work.

Yes, we do parametrize using clean-data prediction, like D3PM. In our case t=1 is clean data, so it's all done in terms of x_1 prediction. In our paper https://arxiv.org/pdf/2402.04997, it is line 3 of Algorithm 1, where the rate we use to simulate is built from an expectation over p(x_1 | x_t), which is our clean-data prediction neural network. In the code it is in the following lines:

In the notebooks it is `x1_probs = F.softmax(logits, dim=-1)  # (B, D, S)`, which gets the clean-data probabilities during sampling.

For the main code, sample.py, it is https://github.com/andrew-cr/discrete_flow_models/blob/800395d172be6b950d2ab87bcf154d752bd2cf76/sample.py#L205, where we get probabilities over clean data during sampling.
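As a rough illustration of how a clean-data predictor becomes a sampling rate, here is a minimal NumPy sketch. It is not the repository's code. It assumes the masking process with a linear schedule, for which the x_1-conditional rate into token j is δ(x_t = mask) δ(j = x_1) / (1 - t). Taking the expectation over p(x_1 | x_t) then gives p(x_1 = j | x_t) / (1 - t) at masked positions and zero elsewhere. The mask id `MASK`, the helpers `softmax` and `masked_rate`, and the shapes are hypothetical; `softmax` stands in for `F.softmax(logits, dim=-1)`.

```python
import numpy as np

MASK = -1  # hypothetical mask-token id used only in this sketch

def softmax(logits, axis=-1):
    # Numerically stable softmax, standing in for F.softmax(logits, dim=-1)
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def masked_rate(x_t, logits, t):
    """Expected rate R_t(x_t -> j) for every position and clean token j.

    Assumes the masking process with a linear schedule, where the
    x_1-conditional rate is delta(x_t == MASK) * delta(j == x_1) / (1 - t).
    Its expectation under p(x_1 | x_t) is
    delta(x_t == MASK) * p(x_1 = j | x_t) / (1 - t).

    x_t: (B, D) int array, logits: (B, D, S), t: float in [0, 1).
    """
    x1_probs = softmax(logits, axis=-1)     # (B, D, S) clean-data probabilities
    is_mask = (x_t == MASK)[..., None]      # (B, D, 1) only masked positions move
    return is_mask * x1_probs / (1.0 - t)   # (B, D, S); zero rows where unmasked
```

With uniform logits over S = 3 tokens and t = 0.5, each masked position gets rate (1/3) / 0.5 = 2/3 toward every token, while unmasked positions get zero. The 1/(1 - t) factor is also where the singularity near t = 1 comes from.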

Abc11c commented 4 months ago

Thanks @andrew-cr,

  1. Just to confirm: we use the rate matrix while t < 1.0 and then just use argmax at the end (t = 1.0, pure data?). This is analogous to the way D3PM works at training time, but with the flow models we can ignore this parameterization during training and use the rate matrix only during sampling?

    D3PM loss: at the first timestep (here t = 1.0), return the decoder NLL; otherwise return KL(q(x_{t-1} | x_t, x_0) || p(x_{t-1} | x_t)).

    Flow loss: directly use the cross-entropy (direct parameterization).

  2. A dumb question: what is the purpose of the argmax here? Should there also be an argmax at the end of the loop in toycode_general.ipynb?
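The "flow loss" side of this comparison can be sketched in a few lines. This is a hypothetical NumPy illustration rather than the repo's training loop. It assumes the masking interpolant, in which each clean token is kept with probability t and otherwise replaced by `MASK`. The loss is then just the cross-entropy of the x_1 predictor, with no rate matrix involved. `model`, `corrupt`, `x1_cross_entropy`, and `training_loss` are placeholder names, and averaging over all positions (rather than only masked ones) is a choice made for this sketch.

```python
import numpy as np

MASK = -1  # hypothetical mask-token id used only in this sketch

def corrupt(x1, t, rng):
    # Masking interpolant: keep each clean token with probability t, else MASK.
    # t has shape (B, 1) so each sequence gets its own time.
    keep = rng.random(x1.shape) < t
    return np.where(keep, x1, MASK)

def log_softmax(logits):
    # Numerically stable log-softmax over the last (token) axis
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def x1_cross_entropy(logits, x1):
    # Mean of -log p_theta(x_1 | x_t) over batch and positions
    logp = log_softmax(logits)                                  # (B, D, S)
    picked = np.take_along_axis(logp, x1[..., None], axis=-1)   # (B, D, 1)
    return -picked.mean()

def training_loss(model, x1, rng):
    # One Monte Carlo estimate: t ~ U(0, 1) per sequence, x_t ~ p_t(. | x_1)
    t = rng.random((x1.shape[0], 1))
    x_t = corrupt(x1, t, rng)
    logits = model(x_t, t)                                      # (B, D, S)
    return x1_cross_entropy(logits, x1)
```

A model that outputs uniform logits over S tokens gets a loss of exactly log S, which is a handy sanity check. By contrast, the D3PM objective would need the posterior q(x_{t-1} | x_t, x_0) and a KL term at every step.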

Thanks!

andrew-cr commented 4 months ago
  1. Yes, that is correct. We can ignore the rate matrix formulation during training and only train the x_1 predictor using the cross-entropy loss. Then, at sampling time, for steps t < 1.0 we convert it into the rate matrix that we use for sampling. This is indeed different from the D3PM loss, which uses the KL you have written.
  2. This is a slight inconsistency on my part. In the main code, the argmax is there because we only sample up to max_t = 0.98. I did this to avoid any possibility of singularities near t = 1.0, which can arise from denominators of the form 1/(1-t) in the rate matrix. If we only sample to t = 0.98, some tokens will still be masked, so we need to fill in those remaining positions. I could have sampled from the probabilities to get these infilling tokens, but I chose to just take the argmax, i.e. the highest-probability token, for this small number of final tokens. I don't expect it to make much difference whether you use the argmax or sample them properly.

    For toycode_general.ipynb, I was more careful about singularities and found that, on the final step of the while loop, sampling properly with the rate matrix has the effect of infilling all remaining masked positions with a sample from p(x_1 | x_t). So there are never leftover tokens that need infilling, and hence we don't need the argmax. Also, in the toycode_general.ipynb case we may not even use the masking process; we could use another kind of process that does not use mask tokens. With no mask tokens you wouldn't even know which tokens need infilling on the last step, so we can't use an argmax-style final step.
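The two endings described above can be compared in a small Euler-step sketch for the masking process. This is a hypothetical NumPy illustration, not sample.py or the notebook. `model`, `MASK`, and `sample_masked` are placeholders, and it assumes the masking rate p(x_1 = j | x_t) / (1 - t) at masked positions (zero elsewhere). With max_t = 0.98 it ends with an argmax infill of any still-masked positions. With max_t = 1.0 the last step has dt = 1 - t, so the jump probability dt * p / (1 - t) equals p(x_1 = j | x_t) and every masked position is filled by sampling.

```python
import numpy as np

MASK = -1  # hypothetical mask-token id used only in this sketch

def softmax(logits):
    # Numerically stable softmax over the last (token) axis
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sample_masked(model, B, D, n_steps, rng, max_t=0.98):
    """Euler simulation of the masking process from t = 0 (all MASK) to max_t.

    model(x_t, t) is assumed to return logits of shape (B, D, S) for
    p(x_1 | x_t). Positions still masked at max_t get argmax p(x_1 | x_t).
    """
    x = np.full((B, D), MASK)
    ts = np.linspace(0.0, max_t, n_steps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - t
        probs = softmax(model(x, t))                  # (B, D, S) = p(x_1 | x_t)
        S = probs.shape[-1]
        # Probability of jumping to clean token j during [t, t + dt]
        jump = probs * dt / (1.0 - t)                 # (B, D, S)
        # Probability of staying masked; clipped to guard against round-off
        stay = np.clip(1.0 - jump.sum(axis=-1, keepdims=True), 0.0, 1.0)
        step_probs = np.concatenate([jump, stay], axis=-1)  # last column = stay
        # Sample one categorical per position via the inverse CDF
        u = rng.random((B, D, 1))
        choice = (step_probs.cumsum(axis=-1) < u).sum(axis=-1)  # (B, D)
        # Only masked positions move; index S or above means "no jump"
        x = np.where((x == MASK) & (choice < S), choice, x)
    # Argmax infill for anything still masked (needed when max_t < 1)
    still = x == MASK
    if still.any():
        final = softmax(model(x, max_t)).argmax(axis=-1)
        x = np.where(still, final, x)
    return x
```

Setting `max_t=1.0` corresponds to the notebook-style ending: the final step's jump probabilities sum to one, so the argmax branch is effectively only a guard against floating-point round-off. The argmax branch also relies on knowing which positions are masked, which is why it has no analogue for processes without mask tokens.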