atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

Question about `UNetModel` used in `conditional_mnist.ipynb` #96

Closed. ImahnShekhzadeh closed this issue 10 months ago

ImahnShekhzadeh commented 10 months ago

Hi,

First of all, thanks for this amazing repository!

In conditional_mnist.ipynb, the UNetModel is instantiated like this:

model = UNetModel(
    dim=(1, 28, 28), num_channels=32, num_res_blocks=1, num_classes=10, class_cond=True
).to(device)

The above code works fine, but I do not fully understand why: UNetModel itself does not accept these arguments (dim, num_channels, class_cond); only the wrapper UNetModelWrapper does.

kilianFatras commented 10 months ago

Thank you for your interest.

The UNetModel class takes an argument "num_classes" that defaults to None (https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/unet.py#L410). When num_classes is not None, it is used to create a label embedding (see https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/unet.py#L448).

For conditional generation, we call the wrapper with two parameters, class_cond and num_classes (see https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/unet.py#L913).

I hope this helps.

kilianFatras commented 10 months ago

Ah, I see what you mean. We have an __init__.py file in the unet folder (https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/__init__.py) that contains this import:

from .unet import UNetModelWrapper as UNetModel

So when you import UNetModel from torchcfm.models.unet, you always get UNetModelWrapper.
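
To make the aliasing concrete, here is a minimal, self-contained sketch that does not use torchcfm. The two classes are toy stand-ins with simplified, assumed signatures (not the library's real ones), and toy_unet_pkg is a hypothetical in-memory module that plays the role of the package's __init__.py.

```python
import sys
import types


class UNetModel:
    """Toy stand-in for the base UNet (assumed, simplified signature)."""

    def __init__(self, image_size, num_classes=None):
        self.image_size = image_size
        self.num_classes = num_classes  # None means "no label embedding"


class UNetModelWrapper(UNetModel):
    """Toy stand-in for the wrapper that accepts dim / class_cond."""

    def __init__(self, dim, class_cond=False, num_classes=None):
        # Assumption for illustration: num_classes is only forwarded
        # when class_cond is True.
        super().__init__(
            image_size=dim[-1],
            num_classes=num_classes if class_cond else None,
        )


# Simulate `from .unet import UNetModelWrapper as UNetModel` in __init__.py:
# the package exposes the wrapper under the name UNetModel.
pkg = types.ModuleType("toy_unet_pkg")  # hypothetical package name
pkg.UNetModel = UNetModelWrapper
sys.modules["toy_unet_pkg"] = pkg

# What user code sees when it imports "UNetModel" from the package.
from toy_unet_pkg import UNetModel as ImportedUNetModel

model = ImportedUNetModel(dim=(1, 28, 28), class_cond=True, num_classes=10)
print(ImportedUNetModel is UNetModelWrapper)  # True
print(type(model).__name__)                   # UNetModelWrapper
```

The name bound by the import is just another reference to the wrapper class, which is why wrapper-only arguments are accepted even though the call site says UNetModel.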

ImahnShekhzadeh commented 10 months ago

Thanks a lot, this is very helpful!