willisma / SiT

Official PyTorch Implementation of "SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers"
https://scalable-interpolant.github.io/
MIT License

sample_eps/train_eps are not passed to Transport in create_transport #1

Closed. Daiver closed this issue 9 months ago

Daiver commented 9 months ago

Hi! I was looking through the code and noticed that in the function create_transport, the variables train_eps and sample_eps are defined but never used. Here is the original code:

    if (path_type in [PathType.VP]):
        train_eps = 1e-5 if train_eps is None else train_eps
        sample_eps = 1e-3 if train_eps is None else sample_eps
    elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if train_eps is None else sample_eps
    else: # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        # I suppose that it should be
        # train_eps=train_eps,
        # sample_eps=sample_eps,
    )

    return state

I'm not sure whether it's important, but I thought you might be interested in fixing it.
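For illustration, here is a minimal, self-contained sketch of what the suggested fix could look like. The enum members and the `Transport` dataclass below are stand-ins written for this example (the real classes live in the repository and may differ), and `create_transport_sketch` is a hypothetical name. The sketch also tests `sample_eps is None` when defaulting `sample_eps`; in the quoted code that condition checks `train_eps`, which has already been assigned on the previous line, so the default would never apply. That change is an assumption about the intended behavior, not something confirmed in this thread.

```python
import enum
from dataclasses import dataclass
from typing import Optional


class PathType(enum.Enum):
    # Stand-in for the repository's PathType (members assumed from the quoted code).
    LINEAR = enum.auto()
    GVP = enum.auto()
    VP = enum.auto()


class ModelType(enum.Enum):
    # Stand-in for the repository's ModelType; only VELOCITY appears in the quoted code.
    NOISE = enum.auto()
    SCORE = enum.auto()
    VELOCITY = enum.auto()


@dataclass
class Transport:
    # Hypothetical stand-in: assumes Transport accepts train_eps and sample_eps keywords.
    model_type: ModelType
    path_type: PathType
    loss_type: Optional[str]
    train_eps: float
    sample_eps: float


def create_transport_sketch(path_type, model_type, loss_type=None,
                            train_eps=None, sample_eps=None):
    if path_type is PathType.VP:
        train_eps = 1e-5 if train_eps is None else train_eps
        # Assumption: default sample_eps based on sample_eps itself.
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in (PathType.GVP, PathType.LINEAR) and model_type is not ModelType.VELOCITY:
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # The point of the issue: actually forward both values to Transport.
    return Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
    )
```

With this version, calling `create_transport_sketch(PathType.VP, ModelType.NOISE)` yields a `Transport` whose `train_eps` and `sample_eps` carry the computed defaults instead of being silently dropped.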

willisma commented 9 months ago

Hi Daiver,

Thanks for spotting the issue! The code is updated accordingly.