atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

sigma type & doc #72

Closed: guillaumehu closed this 1 year ago.

guillaumehu commented 1 year ago

Adds docstrings to OTPlanSampler.__init__ and fixes issue #56 by checking the type of sigma during initialization.

kilianFatras commented 1 year ago

Hello Guillaume!

Thanks for the PR. I think the sigma type should be checked in all classes, including OT-CFM, SB-CFM, FM, and Stochastic Interpolants. @josephdviviano, what do you think? We want to ensure that sigma is a float, because when it is an int we get a bug.
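A minimal sketch of the init-time check suggested here, assuming the goal is simply to store sigma as a Python float. SketchMatcher is a hypothetical class, not one of torchcfm's matchers, and raising on non-numeric input is one design choice among several.

```python
import numbers


class SketchMatcher:
    """Hypothetical stand-in for a matcher; only the sigma handling is shown."""

    def __init__(self, sigma=0.0):
        # Reject anything that is not a real number (bool is excluded on purpose).
        if isinstance(sigma, bool) or not isinstance(sigma, numbers.Real):
            raise TypeError(f"sigma must be a real number, got {type(sigma).__name__}")
        # Store sigma as a Python float so an int such as 1 becomes 1.0.
        self.sigma = float(sigma)


m = SketchMatcher(sigma=1)
print(type(m.sigma).__name__, m.sigma)
```

Because torch tensors are not registered as numbers.Real, this variant rejects a tensor sigma; a looser version could call float(sigma) directly, which also accepts one-element tensors.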

atong01 commented 1 year ago

I think we want to make sure sigma and t have the same dtype as x at runtime, rather than type-checking at init. If someone passes x as a float64 tensor right now, we might also have problems. This would be a good place to put a test @kilianFatras 😆.
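A minimal sketch of the runtime alternative: cast sigma and t to x's dtype and device inside the sampling step instead of validating at init. compute_xt_sketch is a hypothetical function that follows the usual conditional flow matching path mu_t = t * x1 + (1 - t) * x0 with a constant noise level sigma; it is not torchcfm's implementation.

```python
import torch


def compute_xt_sketch(x0, x1, t, sigma, epsilon):
    """Hypothetical compute_xt: straight-line interpolation plus Gaussian noise.

    t and sigma are cast to x0's dtype and device at call time, so the caller
    may pass an int, a float, or a tensor of any floating dtype.
    """
    t = torch.as_tensor(t, dtype=x0.dtype, device=x0.device)
    sigma = torch.as_tensor(sigma, dtype=x0.dtype, device=x0.device)
    # Reshape t from (batch,) to (batch, 1, ..., 1) so it broadcasts against x0.
    t = t.reshape(-1, *([1] * (x0.dim() - 1)))
    mu_t = t * x1 + (1 - t) * x0
    return mu_t + sigma * epsilon


x0 = torch.randn(4, 2, dtype=torch.float64)
x1 = torch.randn(4, 2, dtype=torch.float64)
t = torch.rand(4)  # float32 on purpose, to show the cast
xt = compute_xt_sketch(x0, x1, t, sigma=1, epsilon=torch.randn_like(x0))
print(xt.dtype)
```

Casting at call time keeps the constructor permissive and ties every intermediate to the dtype the caller actually chose for x.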

guillaumehu commented 1 year ago

@atong01 I just tested it, and it works with float64 x even when sigma is initialized as an int, a float, or a tensor (float32 or float64). t already has the right dtype, since you specify t = torch.rand(x0.shape[0]).type_as(x0).
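The float64 result reported here is consistent with PyTorch's type promotion rules: a Python int or float, or a zero-dimensional tensor of either float dtype, combined with a float64 tensor yields float64. The snippet is an illustrative check under that reading, not the test that was actually run; x0 and the sigma values are stand-ins.

```python
import torch

x0 = torch.randn(8, 3, dtype=torch.float64)

# t inherits x0's dtype through type_as, as in the quoted line.
t = torch.rand(x0.shape[0]).type_as(x0)

# Candidate sigma types: int, float, and 0-dim float32/float64 tensors.
sigmas = [1, 0.5, torch.tensor(0.5, dtype=torch.float32), torch.tensor(0.5, dtype=torch.float64)]

# Under type promotion, each product with a float64 tensor stays float64.
result_dtypes = [(sigma * torch.randn_like(x0)).dtype for sigma in sigmas]
print(t.dtype, result_dtypes)
```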

kilianFatras commented 1 year ago

@atong01 Maybe a better solution would be to make the type of sigma_t match the type of x within the compute_xt function?

guillaumehu commented 1 year ago

But they already have the same type, since torch works with mixed precision. The problem is here on L34. https://github.com/atong01/conditional-flow-matching/blob/21cd0c888186f6e2b76deb393800361b8a850e9b/torchcfm/conditional_flow_matching.py#L34-L36

A simple fix is to replace the check on L34 with the condition isinstance(t, (float, int)); this works without checking the type of sigma at initialization.
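The linked lines are not reproduced in this thread, so the sketch below assumes they sit in a helper that returns t unchanged when it is a Python float and otherwise reshapes it to broadcast against x. Under that assumption an int sigma falls through to .reshape and raises AttributeError, which the broader isinstance check avoids. Both function names are hypothetical.

```python
import torch


def pad_float_only(t, x):
    """Assumed original behavior: only Python floats bypass the reshape."""
    if isinstance(t, float):
        return t
    return t.reshape(-1, *([1] * (x.dim() - 1)))


def pad_float_or_int(t, x):
    """Proposed fix: Python ints bypass the reshape as well."""
    if isinstance(t, (float, int)):
        return t
    # Reshape (batch,) to (batch, 1, ..., 1) so t broadcasts against x.
    return t.reshape(-1, *([1] * (x.dim() - 1)))


x = torch.randn(4, 3, 2)
try:
    pad_float_only(1, x)  # an int sigma reaches .reshape
except AttributeError as err:
    print("float-only check fails:", err)

print(pad_float_or_int(1, x))                    # 1
print(pad_float_or_int(torch.rand(4), x).shape)  # torch.Size([4, 1, 1])
```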

josephdviviano commented 1 year ago

I don't think I fully understand how the code works yet; give me a moment. Initial thoughts:

> torch works with mixed precision.

kilianFatras commented 1 year ago

I'll let @josephdviviano decide when this PR is ready, as he knows more than I do about how to fix a variable's type.

@guillaumehu, once this PR is ready, please prepare a test to add to the add_tests branch (or to this PR, but I would prefer to keep all tests in the same PR). I have written tests for all classes in torchcfm, and we need a test for sigma's type as well.
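One possible shape for the requested test, written as a parametrized pytest case over sigma types and x dtypes. sample_xt_stub is a hypothetical stand-in so the snippet runs on its own; in the real test it would be replaced by the torchcfm class under test.

```python
import pytest
import torch


def sample_xt_stub(x0, x1, sigma):
    """Hypothetical stand-in for a matcher's sampling step."""
    t = torch.rand(x0.shape[0]).type_as(x0)
    t = t.reshape(-1, *([1] * (x0.dim() - 1)))
    return t * x1 + (1 - t) * x0 + sigma * torch.randn_like(x0)


@pytest.mark.parametrize("x_dtype", [torch.float32, torch.float64])
@pytest.mark.parametrize(
    "sigma",
    [0, 1, 0.1, torch.tensor(0.1, dtype=torch.float32), torch.tensor(0.1, dtype=torch.float64)],
)
def test_sigma_types(sigma, x_dtype):
    x0 = torch.randn(16, 2, dtype=x_dtype)
    x1 = torch.randn(16, 2, dtype=x_dtype)
    xt = sample_xt_stub(x0, x1, sigma)
    # The sample must keep x's dtype and shape whatever the type of sigma.
    assert xt.dtype == x_dtype
    assert xt.shape == x0.shape
```

Parametrizing over both sigma and the dtype of x covers the int-sigma bug and the float64 concern raised earlier in the thread.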