Closed · guillaumehu closed this pull request 1 year ago
Hello Guillaume!
Thanks for the PR. I think the `sigma` type should be checked for all classes: that includes OT-CFM, SB-CFM, FM, and Stochastic Interpolants. @josephdviviano what do you think? We want to ensure that `sigma` is a float, because when it is an int we get a bug.
I think we want to make sure `sigma` and `t` have the same type as `x` at runtime, and not type check on init. I think if someone has an `x` which is a float64 tensor right now, we might also have problems. Good place to put a test @kilianFatras 😆.
@atong01 I just tested it, and it works with float64 even if `sigma` is initialized as an int, a float, or a tensor (32 or 64). `t` has the right dtype since you already specify `t = torch.rand(x0.shape[0]).type_as(x0)`.
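For reference, a minimal check along these lines (standalone snippet, not code from the repo) shows why `type_as` keeps `t` consistent with `x0`, and why a scalar `sigma`, whatever its Python type, does not change the result dtype:

```python
import torch

# x0 in float64, the scenario discussed above
x0 = torch.randn(8, 2, dtype=torch.float64)

# t inherits x0's dtype thanks to type_as
t = torch.rand(x0.shape[0]).type_as(x0)
assert t.dtype == torch.float64

# A scalar sigma (int, Python float, or float32 tensor) broadcasts
# without changing the dtype of the float64 result.
for sigma in (1, 0.1, torch.tensor(0.1)):
    xt = x0 + sigma * torch.randn_like(x0)
    assert xt.dtype == torch.float64
```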
@atong01 Maybe a better solution would be to make the type of `sigma_t` match the type of `x` within the `compute_xt` function?
But they already have the same type, since torch works with mixed precision. The problem is on L34 here: https://github.com/atong01/conditional-flow-matching/blob/21cd0c888186f6e2b76deb393800361b8a850e9b/torchcfm/conditional_flow_matching.py#L34-L36

A simple fix is replacing L34 with `if isinstance(t, (float, int)):`; this works without checking the type of `sigma` in the initialization.
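For context, a sketch of what the helper at the linked lines and the proposed fix might look like (a reconstruction from this discussion, not the exact code at that commit): with the current `isinstance(t, float)` check, an int `sigma` falls through to `t.reshape(...)`, which fails because a Python int has no `reshape` method.

```python
import torch

def pad_t_like_x(t, x):
    """Return a scalar unchanged, or reshape a batch of scalars to broadcast against x."""
    # Proposed fix: accept Python ints as well as floats, so an int sigma
    # is returned as-is instead of falling through to t.reshape(...).
    if isinstance(t, (float, int)):
        return t
    return t.reshape(-1, *([1] * (x.dim() - 1)))

x = torch.randn(8, 2)
print(pad_t_like_x(0.1, x))                  # scalar passes through unchanged
print(pad_t_like_x(1, x))                    # would raise AttributeError without the fix
print(pad_t_like_x(torch.rand(8), x).shape)  # tensors are reshaped to (8, 1)
```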
I think I don't fully understand how the code works. Just a moment. But initial thoughts:

- Re "torch works with mixed precision": mixing `int` and `float` should always return a `float`. This might be confusing for the user, if they submit an `int` and expect another `int` in return. It's fine if things need to be this way, but I think it would be good for it to be consistent and documented (either types are always conserved or they're always cast to `float`). Sorry in advance if I misunderstand the code and one of those two conditions already holds.
- `@property` should not do any type casting, because this variable `self.sigma` should always equal `self._sigma`; otherwise it leaves the door open for a very confusing user experience if they're grappling with some strange type error.

I let @josephdviviano decide when this PR is ready, as he has more knowledge than me on how to fix a variable's type.
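To illustrate the `@property` point above, a minimal sketch (hypothetical class name, not code from this PR) of the invariant being asked for: any casting happens once in `__init__`, and the property is a plain read of the stored value, so `self.sigma` always equals `self._sigma`.

```python
class ExampleFlowMatcher:
    def __init__(self, sigma=0.1):
        # Cast once, up front, and document it; after this point the
        # stored value is never silently converted again.
        self._sigma = float(sigma)

    @property
    def sigma(self):
        # No casting here: self.sigma == self._sigma always holds.
        return self._sigma
```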
@guillaumehu once this PR is ready, please prepare a test to add to the add_tests branch (or to this PR but I would prefer all tests within the same PR). I have made tests for all classes within torchcfm and we need a test on sigma's type as well.
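As a starting point, a hedged sketch of such a test (pytest style; `ConditionalFlowMatcher` and `sample_location_and_conditional_flow` are the torchcfm names as I understand them from this thread, so adjust if the API differs):

```python
import pytest
import torch
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher

@pytest.mark.parametrize("sigma", [0, 1, 0.1])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_sigma_type_does_not_break_sampling(sigma, dtype):
    fm = ConditionalFlowMatcher(sigma=sigma)
    x0 = torch.randn(8, 2, dtype=dtype)
    x1 = torch.randn(8, 2, dtype=dtype)
    t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
    # t, xt and ut should all follow the dtype of the inputs,
    # regardless of whether sigma was passed as an int or a float.
    assert t.dtype == dtype
    assert xt.dtype == dtype
    assert ut.dtype == dtype
```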
Adding docstrings for the `__init__` of `OTPlanSampler`, and fixing issue #56 by checking the type of `sigma` during initialization.
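A minimal sketch of what "checking the type of sigma during initialization" could look like (the class name matches the one discussed in this thread, but the body is illustrative and not this PR's actual diff):

```python
class ConditionalFlowMatcher:
    def __init__(self, sigma=0.0):
        """Initialize the flow matcher.

        Parameters
        ----------
        sigma : float
            Standard deviation of the probability path. An int is accepted
            for convenience but cast to float, so downstream tensor ops
            never see a bare Python int.
        """
        if not isinstance(sigma, (float, int)):
            raise TypeError(f"sigma must be a float, got {type(sigma).__name__}")
        self.sigma = float(sigma)
```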