ImahnShekhzadeh closed this issue 10 months ago.
Thank you for your interest.
The UNet model has an input argument `num_classes` that defaults to `None` (https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/unet.py#L410). When `num_classes` is not `None`, it is used to create a label embedding (see https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/unet.py#L448).
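To illustrate the pattern above (an optional `num_classes` that switches a label embedding on or off), here is a toy PyTorch module. It is a hypothetical sketch, not the torchcfm UNet: the class name `TinyConditionalNet`, the linear layer standing in for the time embedding, and adding the label embedding to the time embedding (as guided-diffusion-style UNets commonly do) are all assumptions.

```python
from typing import Optional

import torch
import torch.nn as nn


class TinyConditionalNet(nn.Module):
    """Hypothetical toy model that only mirrors the optional num_classes pattern."""

    def __init__(self, emb_dim: int = 8, num_classes: Optional[int] = None):
        super().__init__()
        self.num_classes = num_classes
        # Stand-in for the time-embedding MLP of a real UNet (assumption).
        self.time_embed = nn.Linear(1, emb_dim)
        if num_classes is not None:
            # Only class-conditional models get a label embedding table.
            self.label_emb = nn.Embedding(num_classes, emb_dim)

    def forward(self, t: torch.Tensor, y: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Embed the (float) time values: shape (batch,) -> (batch, emb_dim).
        emb = self.time_embed(t.view(-1, 1))
        if self.num_classes is not None:
            if y is None:
                raise ValueError("class labels y are required when num_classes is set")
            # Condition on the class by adding its embedding to the time embedding.
            emb = emb + self.label_emb(y)
        return emb


unconditional = TinyConditionalNet()             # no label_emb is created
conditional = TinyConditionalNet(num_classes=10)  # label_emb has 10 rows
t = torch.rand(4)
y = torch.tensor([0, 1, 2, 3])
out = conditional(t, y)
```

With `num_classes=None` the module never creates `label_emb` and ignores labels entirely, which is why the same class can serve both unconditional and conditional generation.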
For conditional generation, we instantiate the wrapper with two parameters, `class_cond` and `num_classes` (see https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/unet.py#L913).
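The wrapper idea can be sketched in plain Python. `InnerUNet` and `InnerUNetWrapper` are hypothetical stand-ins, and the forwarding rule (pass `num_classes` through only when `class_cond` is true, otherwise pass `None`) is an assumption about how such a wrapper could translate its arguments, not a copy of the torchcfm code.

```python
from typing import Optional


class InnerUNet:
    """Hypothetical stand-in for the low-level UNet; it only records its config."""

    def __init__(self, model_channels: int, num_classes: Optional[int] = None):
        self.model_channels = model_channels
        self.num_classes = num_classes
        # Mirrors "create a label embedding only when num_classes is not None".
        self.has_label_emb = num_classes is not None


class InnerUNetWrapper(InnerUNet):
    """Hypothetical wrapper exposing class_cond and translating it for the parent."""

    def __init__(self, num_channels: int, class_cond: bool = False,
                 num_classes: Optional[int] = None):
        if class_cond and num_classes is None:
            raise ValueError("class_cond=True requires num_classes")
        super().__init__(
            model_channels=num_channels,
            # Forward num_classes only when class conditioning is requested.
            num_classes=num_classes if class_cond else None,
        )


conditional = InnerUNetWrapper(num_channels=32, class_cond=True, num_classes=10)
unconditional = InnerUNetWrapper(num_channels=32, num_classes=10)
```

Because the wrapper subclasses the inner model in this sketch, the resulting object is still an `InnerUNet`; the wrapper only changes which constructor arguments callers see.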
I hope this helps.
Ah, I see what you mean. There is an `__init__.py` file in the unet folder (https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/__init__.py) that contains:
`from .unet import UNetModelWrapper as UNetModel`
so importing `UNetModel` from that package always gives you `UNetModelWrapper`.
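To see why this alias means the notebook is really constructing the wrapper, the sketch below writes a hypothetical two-file package (`demo_unet`) to a temporary directory using the same re-export line, then checks what the package-level name points to. The class bodies are empty placeholders, and having the wrapper subclass the low-level model is an assumption made for the demo.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical module contents: a low-level model and a wrapper subclass.
UNET_PY = '''
class UNetModel:
    """Low-level model (placeholder)."""


class UNetModelWrapper(UNetModel):
    """Wrapper around the low-level model (placeholder)."""
'''

# Same re-export line as the package __init__.py discussed above.
INIT_PY = "from .unet import UNetModelWrapper as UNetModel\n"


def load_demo_package():
    """Write demo_unet/ to a temp dir, put it on sys.path, and import it."""
    root = Path(tempfile.mkdtemp())
    pkg = root / "demo_unet"
    pkg.mkdir()
    (pkg / "unet.py").write_text(UNET_PY)
    (pkg / "__init__.py").write_text(INIT_PY)
    sys.path.insert(0, str(root))
    importlib.invalidate_caches()
    return importlib.import_module("demo_unet")


demo = load_demo_package()
# The package-level UNetModel is the wrapper, not demo_unet.unet.UNetModel.
print(demo.UNetModel is demo.unet.UNetModelWrapper)  # True
print(demo.UNetModel is demo.unet.UNetModel)  # False
```

The `as` clause only rebinds a name in the package namespace; the low-level class is still reachable as `unet.UNetModel` inside the submodule.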
Thanks a lot, this is very helpful!
Hi,
First of all thanks for this amazing repository!
In `conditional_mnist.ipynb`, the `UNetModel` is called like this:
The above code works fine, but I do not fully understand why, since `UNetModel` itself does not have these arguments, but instead the wrapper `UNetModelWrapper`