gle-bellier / discrete-fm

Educational implementation of the Discrete Flow Matching paper

How do you set the mask_prob in CCoupling? Why is it 0.8? #1

Open dongzhuoyao opened 1 week ago

dongzhuoyao commented 1 week ago

class Ccoupling(Coupling):
    def __init__(self, msk_prop: float = 0.8) -> None:
        self.msk_prob = msk_prop

    def sample(self, x1: Img) -> tuple[Img, Img]:
        # sample mask
        I = torch.rand_like(x1.float()) > self.msk_prob
        x0 = x1 * I + torch.zeros_like(x1) * (~I)
        return x0, x1

gle-bellier commented 1 week ago

Thanks for your feedback on this notebook. This is a re-interpretation of what $p_0$ can be for images in certain applications; here it is the distribution of 80%-masked images. It was meant more as an illustrative example, since the rest of the notebook does not handle the Ccoupling case (it requires learning both the noise predictor and the denoiser if you want to use forward and backward velocities, e.g. for corrector sampling). More details about the Ccoupling for text are included in Appendix B (page 15), which shows how to sample the proportion of the data to mask.
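
To make the role of the mask proportion concrete, here is a minimal pure-Python sketch on a toy token sequence. It is not the notebook's code: `MASK_ID`, `sample_ccoupling`, and `sample_ccoupling_random_ratio` are hypothetical names, and drawing the proportion uniformly per sequence is an assumption made for illustration, not necessarily the scheme described in Appendix B.

```python
import random

MASK_ID = 0  # hypothetical id reserved for the mask token (assumption)


def sample_ccoupling(x1, msk_prob, rng):
    """Return (x0, x1): each token of x1 is masked independently with probability msk_prob."""
    x0 = [MASK_ID if rng.random() < msk_prob else tok for tok in x1]
    return x0, x1


def sample_ccoupling_random_ratio(x1, rng, low=0.0, high=1.0):
    """Same coupling, but the mask proportion is drawn per sequence from U(low, high).

    The uniform choice is an illustrative assumption, not taken from the paper.
    """
    msk_prob = rng.uniform(low, high)
    x0, _ = sample_ccoupling(x1, msk_prob, rng)
    return x0, x1, msk_prob


rng = random.Random(0)
# Toy "data" tokens in 1..255, so they never collide with MASK_ID.
x1 = [rng.randint(1, 255) for _ in range(10_000)]

# Fixed proportion, as in the quoted class (msk_prob = 0.8).
x0, _ = sample_ccoupling(x1, msk_prob=0.8, rng=rng)
masked_fraction = sum(tok == MASK_ID for tok in x0) / len(x0)
print(f"fixed msk_prob=0.8 -> masked fraction {masked_fraction:.3f}")

# Proportion sampled per sequence.
x0_r, _, p = sample_ccoupling_random_ratio(x1, rng)
masked_fraction_r = sum(tok == MASK_ID for tok in x0_r) / len(x0_r)
print(f"sampled msk_prob={p:.3f} -> masked fraction {masked_fraction_r:.3f}")
```

With a fixed `msk_prob`, every $x_0$ has roughly the same fraction of masked positions, which matches the "80%-masked images" reading above. Sampling the proportion instead exposes the model to a range of masking levels.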