atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

Confusion about the TargetConditionalFlowMatcher #151

Closed: tkun-li closed this issue 5 days ago

tkun-li commented 5 days ago

Hi,

Thanks for releasing this repo for us to learn flow matching! While looking at TargetConditionalFlowMatcher in conditional_flow_matching.py, I noticed that in the sample_xt() function the output is something like mu_t + sigma_t * torch.randn_like(x0). In the paper "Flow Matching for Generative Modeling", the final loss function of Example II is written as follows (Eq. 23):

[image: Eq. 23 from the paper]

My confusion: shouldn't the output of sample_xt() be something like mu_t + sigma_t * x0, as in Eq. 22, rather than mu_t + sigma_t * torch.randn_like(x0)? The loss function above uses x0 itself instead of sampling new random noise. Am I missing something, or are these equivalent?

Sorry for the inconvenience. I am new to the field. Many thanks!
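To make the comparison in the question concrete, here is a sketch of the algebra. It assumes the paper's Example II (optimal-transport path) with $\mu_t = t\,x_1$, $\sigma_t = 1 - (1 - \sigma_{\min})t$, and a standard-normal source $x_0 \sim \mathcal{N}(0, I)$ independent of $x_1$; $\epsilon$ stands for the fresh noise drawn by `torch.randn_like(x0)`. Plugging $x_t = \mu_t + \sigma_t\,\epsilon$ into the conditional vector field gives:

```latex
\begin{aligned}
x_t &= \mu_t + \sigma_t\,\epsilon
     = t\,x_1 + \bigl(1 - (1-\sigma_{\min})t\bigr)\,\epsilon,
     \qquad \epsilon \sim \mathcal{N}(0, I),\\
u_t(x_t \mid x_1) &= \frac{x_1 - (1-\sigma_{\min})\,x_t}{1 - (1-\sigma_{\min})t}
     = \frac{\sigma_t\,x_1 - (1-\sigma_{\min})\,\sigma_t\,\epsilon}{\sigma_t}
     = x_1 - (1-\sigma_{\min})\,\epsilon .
\end{aligned}
```

Under these assumptions, the regression pair $(x_t, u_t)$ built from $\epsilon$ has the same form as the one in Eq. 23, with $x_0$ replaced by $\epsilon$. Since both are $\mathcal{N}(0, I)$ and independent of $x_1$, the two constructions yield the same joint distribution of $(t, x_1, x_t, u_t)$ and hence the same expected loss. The argument leans on the Gaussian-source assumption and would not carry over unchanged to a non-Gaussian $x_0$.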

atong01 commented 5 days ago

Hi,

My intuition says these should be equivalent because it's the sum of two normals, but I haven't rigorously checked.
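A quick numerical check of this intuition, written as a minimal NumPy sketch rather than torchcfm's own code. It assumes a standard-normal source `x0`, a hypothetical `SIGMA_MIN = 0.1`, one fixed data point `x1` and time `t`, and compares $x_t$ built by reusing `x0` (as in Eq. 23) against $x_t$ built from fresh noise (as `sample_xt()` does). The helper names `sigma_t` and `xt_and_target` are illustrative, not part of the library's API.

```python
import numpy as np

SIGMA_MIN = 0.1  # hypothetical sigma_min, chosen only for illustration


def sigma_t(t, sigma_min=SIGMA_MIN):
    # Std of the assumed OT conditional path: 1 - (1 - sigma_min) * t
    return 1.0 - (1.0 - sigma_min) * t


def xt_and_target(noise, x1, t, sigma_min=SIGMA_MIN):
    """Hypothetical helper: x_t = t*x1 + sigma_t*noise and its target u_t(x_t | x1)."""
    xt = t * x1 + sigma_t(t, sigma_min) * noise
    # Conditional vector field (x1 - (1 - sigma_min) x) / (1 - (1 - sigma_min) t) at x = x_t
    ut = (x1 - (1.0 - sigma_min) * xt) / sigma_t(t, sigma_min)
    return xt, ut


rng = np.random.default_rng(0)
n, dim, t = 200_000, 2, 0.3
x1 = np.array([2.0, -1.0])  # one fixed data point

# Variant A (Eq. 23 style): reuse the source sample x0 ~ N(0, I) as the noise.
x0 = rng.standard_normal((n, dim))
xt_a, ut_a = xt_and_target(x0, x1, t)

# Variant B (sample_xt style): use x0 only for its shape, draw fresh eps ~ N(0, I).
eps = rng.standard_normal(x0.shape)
xt_b, ut_b = xt_and_target(eps, x1, t)

# Per-sample identity: the target collapses to x1 - (1 - sigma_min) * noise.
assert np.allclose(ut_a, x1 - (1.0 - SIGMA_MIN) * x0)
assert np.allclose(ut_b, x1 - (1.0 - SIGMA_MIN) * eps)

# Distributional match: both x_t clouds should have mean t*x1 and std sigma_t(t).
print(xt_a.mean(axis=0), xt_b.mean(axis=0))
print(xt_a.std(axis=0), xt_b.std(axis=0))
```

Because this path's mean $t\,x_1$ does not involve $x_0$, only the distribution of the noise matters, so fresh standard-normal noise is a relabeling of the same random variable. The moment comparison covers only one `x1` and `t`; the per-sample identity checked by the asserts is the stronger statement.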

tkun-li commented 5 days ago

Ok. Thanks for the quick reply!