willisma / SiT

Official PyTorch Implementation of "SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers"
https://scalable-interpolant.github.io/
MIT License

The reproduction related #14

Closed WANGSSSSSSS closed 6 months ago

WANGSSSSSSS commented 6 months ago

Hello, author. The theory and writing of SiT are really charming. I am interested in this work and want to make some improvements on SiT, but I am struggling to reproduce the metrics reported in the SiT paper. I have noticed some techniques that are important for getting a better FID, but they are not enough to match 2.15 and 2.09.

  1. I find that the diffusion term in the paper uses sqrt(w*dt), i.e. x + v*dt + 0.5*s*w*dt + torch.sqrt(w*dt)*torch.randn_like(x), while this codebase employs sqrt(2*w*dt), i.e. x + v*dt + s*w*dt + torch.sqrt(2*w*dt)*torch.randn_like(x). In my experiments, sqrt(w*dt) yields 2.53 FID while sqrt(2*w*dt) achieves 2.65 (see the sanity check after the step functions below).
  2. I use cfg=1.5 and only apply cfg on the first three channels. cfg=4.0 achieves better visual quality, but the FID metric is much worse. Here are my ODE and SDE solvers; would you mind giving me some suggestions?
    
def c3_guidance_fn(out, cfg):
    # input is the concatenated [unconditional, conditional] batch
    uncondition, condition = out.chunk(2, dim=0)
    out = condition
    # classifier-free guidance applied only on the first three channels
    out[:, :3] = uncondition[:, :3] + cfg * (condition[:, :3] - uncondition[:, :3])
    return out

def ode_step_fn(x, v, dt, s, w):
    return x + v * dt

def sde_mean_step_fn(x, v, dt, s, w):
    return x + v * dt + s * w * dt

def sde_step_fn(x, v, dt, s, w):
    return x + v * dt + s * w * dt + torch.sqrt(2 * w * dt) * torch.randn_like(x)

def sde_preserve_step_fn(x, v, dt, s, w):
    return x + v * dt + 0.5 * s * w * dt + torch.sqrt(w * dt) * torch.randn_like(x)
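
On point 1: both step rules can be read as Euler-Maruyama discretizations of the marginal-preserving SDE family dx = [v + (lam/2) s] dt + sqrt(lam) dW, with lam = w for the sqrt(w*dt) rule and lam = 2w for the sqrt(2*w*dt) rule, so in continuous time they target the same marginals and differ only in discretization error. Here is a self-contained sanity check of this (not code from this repo; the Gaussian toy setup and all names are made up for illustration), using a linear interpolant whose velocity and score are known in closed form:

import math
import torch

m, std = 2.0, 0.5            # toy data distribution N(m, std^2)
n, num_steps = 200_000, 250  # particles / integration steps

def coeffs(t):
    # linear interpolant x_t = alpha_t * x_data + sigma_t * eps,
    # with alpha_t = t, sigma_t = 1 - t (t=0 is noise, t=1 is data)
    alpha, dalpha = t, 1.0
    sigma, dsigma = 1.0 - t, -1.0
    mu = alpha * m                                    # marginal mean
    var = (alpha * std) ** 2 + sigma ** 2             # marginal variance
    dmu = dalpha * m
    dvar = 2 * alpha * dalpha * std ** 2 + 2 * sigma * dsigma
    return mu, var, dmu, dvar, sigma

def simulate(step_rule):
    x = torch.randn(n)                                # marginal at t=0 is N(0, 1)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        mu, var, dmu, dvar, sigma = coeffs(i * dt)
        v = dmu + 0.5 * dvar / var * (x - mu)         # probability-flow velocity
        s = -(x - mu) / var                           # exact score of the marginal
        x = step_rule(x, v, s, sigma, dt)             # w = sigma, as in this thread
    return x

def sde_a(x, v, s, w, dt):                            # 0.5*s*w*dt + sqrt(w*dt) noise
    return x + v * dt + 0.5 * s * w * dt + math.sqrt(w * dt) * torch.randn_like(x)

def sde_b(x, v, s, w, dt):                            # s*w*dt + sqrt(2*w*dt) noise
    return x + v * dt + s * w * dt + math.sqrt(2 * w * dt) * torch.randn_like(x)

for name, rule in [("sqrt(w*dt)", sde_a), ("sqrt(2*w*dt)", sde_b)]:
    out = simulate(rule)
    print(f"{name}: mean={out.mean().item():.3f} (target {m}), std={out.std().item():.3f} (target {std})")

With 250 steps both rules should print a mean near 2.0 and a std near 0.5, so any FID gap between them comes from discretization error rather than from one of the two being wrong.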

import logging

logger = logging.getLogger(__name__)

class FlowMatchEulerSampler(BaseSampler):
    def __init__(
        self,
        num_steps: int = 250,
        guidance=4.0,
        scheduler: BaseScheduler = None,
        w_scheduler: BaseScheduler = None,
        guidance_fn: Callable = c3_guidance_fn,
        step_fn: Callable = ode_step_fn,
        last_step=0.0,
        last_step_fn: Callable = ode_step_fn,
        pred_eps=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.scheduler = scheduler
        self.num_steps = num_steps
        self.pred_eps = pred_eps
        self.guidance = guidance
        self.guidance_fn = guidance_fn
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler

        if self.last_step == 0.0:
            self.last_step = 1.0 / self.num_steps

        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is an ODE sampler, but w_scheduler is enabled")

    def _impl_sampling(self, net, images, labels):
        """Sampling process of the Euler sampler."""
        batch_size = images.shape[0]
        steps = torch.linspace(0.0, 1 - self.last_step, self.num_steps, device=images.device)
        steps = torch.cat([steps, torch.tensor([1.0], device=images.device)], dim=0)

        null_labels = torch.full_like(labels, self.null_class)
        labels = torch.cat([null_labels, labels], dim=0)
        x = images
        dt = steps[1] - steps[0]
        for i, t_cur in enumerate(steps[:-1]):
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            drift_coeff = self.scheduler.drift_coefficient(t_cur)
            diffusion_coeff = self.scheduler.diffuse_coefficient(t_cur)
            dalpha_over_alpha = self.scheduler.dalpha_over_alpha(t_cur)
            dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)
            if self.w_scheduler:
                w = self.w_scheduler.w(t_cur)
            else:
                w = 0.0

            # duplicate the batch for classifier-free guidance
            cfg_x = torch.cat([x, x], dim=0)
            t = t_cur.repeat(2)
            out = net(cfg_x, t, labels)
            out = self.guidance_fn(out, self.guidance)

            if self.pred_eps:
                s = out / sigma
                v = drift_coeff * x + diffusion_coeff * s
            else:
                # convert the predicted velocity into a score estimate
                v = out
                s = ((1 / dalpha_over_alpha) * v - x) / (sigma ** 2 - (1 / dalpha_over_alpha) * dsigma_mul_sigma)
            if i < self.num_steps - 1:
                x = self.step_fn(x, v, dt, s=s, w=w)
            else:
                x = self.last_step_fn(x, v, self.last_step, s=s, w=w)
        return x
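
For context, this is roughly how I wire it up. The BaseSampler/BaseScheduler plumbing is not shown in this thread, so calling _impl_sampling directly and the latent shape are assumptions for illustration; only the constructor arguments come from the class above:

# Hypothetical wiring; BaseSampler's public entry point is not shown here,
# so calling _impl_sampling directly is an assumption for illustration.
sampler = FlowMatchEulerSampler(
    num_steps=250,
    guidance=1.5,
    scheduler=scheduler,        # provides sigma / drift / diffusion coefficients
    w_scheduler=w_scheduler,    # provides w(t); I use w = sigma in my runs
    guidance_fn=c3_guidance_fn,
    step_fn=sde_step_fn,        # or ode_step_fn / sde_preserve_step_fn
    last_step=0.004,
    last_step_fn=ode_step_fn,
)
noise = torch.randn(batch_size, 4, 32, 32)  # assumed SiT-XL/2 latent shape
samples = sampler._impl_sampling(net, noise, labels)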

class FlowMatchHenuSampler(BaseSampler):
    def __init__(
        self,
        num_steps: int = 250,
        guidance=4.0,
        scheduler: BaseScheduler = None,
        w_scheduler: BaseScheduler = None,
        guidance_fn: Callable = c3_guidance_fn,
        pred_eps=False,
        exact_henu=False,
        step_fn: Callable = ode_step_fn,
        last_step=0.04,
        last_step_fn: Callable = ode_step_fn,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.scheduler = scheduler
        self.num_steps = num_steps
        self.pred_eps = pred_eps
        self.guidance = guidance
        self.guidance_fn = guidance_fn
        self.exact_henu = exact_henu
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler

        if self.last_step == 0.0:
            self.last_step = 1.0 / self.num_steps

        assert self.scheduler is not None
        assert self.w_scheduler is not None or self.step_fn in [ode_step_fn, ]
        if self.w_scheduler is not None:
            if self.step_fn == ode_step_fn:
                logger.warning("current sampler is an ODE sampler, but w_scheduler is enabled")

    def _impl_sampling(self, net, images, labels):
        """Sampling process of the Heun sampler."""
        batch_size = images.shape[0]
        steps = torch.linspace(0.0, 1 - self.last_step, self.num_steps, device=images.device)
        steps = torch.cat([steps, torch.tensor([1.0], device=images.device)], dim=0)
        null_labels = torch.full_like(labels, self.null_class)
        labels = torch.cat([null_labels, labels], dim=0)
        x = images
        v_hat, s_hat = 0.0, 0.0
        dt = steps[1] - steps[0]
        for i, (t_cur, t_next) in enumerate(zip(steps[:-1], steps[1:])):
            t_cur = t_cur.repeat(batch_size)
            sigma = self.scheduler.sigma(t_cur)
            drift_coeff = self.scheduler.drift_coefficient(t_cur)
            diffusion_coeff = self.scheduler.diffuse_coefficient(t_cur)
            alpha_over_dalpha = 1 / self.scheduler.dalpha_over_alpha(t_cur)
            dsigma_mul_sigma = self.scheduler.dsigma_mul_sigma(t_cur)

            t_hat = t_next
            t_hat = t_hat.repeat(batch_size)
            sigma_hat = self.scheduler.sigma(t_hat)
            drift_coeff_hat = self.scheduler.drift_coefficient(t_hat)
            diffusion_coeff_hat = self.scheduler.diffuse_coefficient(t_hat)
            alpha_over_dalpha_hat = 1 / self.scheduler.dalpha_over_alpha(t_hat)
            dsigma_mul_sigma_hat = self.scheduler.dsigma_mul_sigma(t_hat)

            if self.w_scheduler:
                w = self.w_scheduler.w(t_cur)
            else:
                w = 0.0
            if i == 0 or self.exact_henu:
                # evaluate the model at t_cur (predictor slope); otherwise
                # reuse the corrector evaluation from the previous step
                cfg_x = torch.cat([x, x], dim=0)
                t_cur = t_cur.repeat(2)
                out = net(cfg_x, t_cur, labels)
                out = self.guidance_fn(out, self.guidance)

                if self.pred_eps:
                    s = out / sigma
                    v = drift_coeff * x + diffusion_coeff * s
                else:
                    v = out
                    s = (alpha_over_dalpha * v - x) / (sigma ** 2 - alpha_over_dalpha * dsigma_mul_sigma)
            else:
                v = v_hat
                s = s_hat

            if i < self.num_steps - 1:
                # Heun corrector: average the slopes at t_cur and t_next
                x_hat = self.step_fn(x, v, dt, s=s, w=w)
                cfg_x_hat = torch.cat([x_hat, x_hat], dim=0)
                t_hat = t_hat.repeat(2)
                out = net(cfg_x_hat, t_hat, labels)
                out = self.guidance_fn(out, self.guidance)
                if self.pred_eps:
                    s_hat = out / sigma_hat
                    v_hat = drift_coeff_hat * x_hat + diffusion_coeff_hat * s_hat
                else:
                    v_hat = out
                    s_hat = (alpha_over_dalpha_hat * v_hat - x_hat) / (sigma_hat ** 2 - alpha_over_dalpha_hat * dsigma_mul_sigma_hat)
                v = (v + v_hat) / 2
                s = (s + s_hat) / 2
                x = self.step_fn(x, v, dt, s=s, w=w)
            else:
                x = self.last_step_fn(x, v, self.last_step, s=s, w=w)
        return x
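
For reference, the corrector logic above is Heun's second-order method ("Henu" in the identifiers looks like a transposition of "Heun"): an Euler predictor followed by a trapezoidal average of the two slopes. For dx/dt = f(x, t) with step size h:

\tilde{x}_{n+1} = x_n + h \, f(x_n, t_n)
x_{n+1} = x_n + \tfrac{h}{2} \left[ f(x_n, t_n) + f(\tilde{x}_{n+1}, t_{n+1}) \right]

With exact_henu=False, the code additionally reuses f(\tilde{x}_{n+1}, t_{n+1}) as the predictor slope of the next step, saving one network evaluation per step at the cost of evaluating at \tilde{x}_{n+1} rather than at the updated x_{n+1}.
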
WANGSSSSSSS commented 6 months ago

I used the pretrained SiT-XL. Euler ODE with cfg=1.5, w=sigma and steps=250 achieves 2.45. Euler SDE with cfg=1.5, w=sigma, steps=250 and last step=0.004 achieves 2.53 (with the sqrt(w*dt) variant). As the computation is really slow, I use fp16 mixed precision as the default setting.

willisma commented 6 months ago

Hi,

FID is well known to be sensitive to many small implementation choices, and using fp16 might be one of them, since the metrics reported in our paper are produced with fp32.

The degradation of FID at larger cfg scales is also well known: ImageNet contains many low-quality images, so computing the metric with only high-visual-quality samples leads to a deviation from the ImageNet distribution. Hope this provides some insights.

Best

WANGSSSSSSS commented 6 months ago

Thank you for your quick reply. At least I have made some progress: I found that the FID metric for ImageNet 256 is computed on 256-sized images, so each image needs to be resized to 256 first and then to 299. After this correction I get 2.149 for the ODE and 2.34 for the SDE variant (x + v*dt + s*w*dt + torch.sqrt(2*w*dt)*torch.randn_like(x)). Sorry that all this mess wastes your time, but I really need your help to align the SDE part. Would you mind reviewing my Euler SDE solver code and checking whether it is correct? 😭
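
For anyone else aligning FID, here is a minimal sketch of the resize order described above. It is not from this repo, and the BICUBIC interpolation filter is an assumption for illustration, not something stated in this thread:

# Hypothetical preprocessing sketch: resize samples to 256 first, then to
# 299 for InceptionV3. The BICUBIC filter is an assumption for illustration.
from PIL import Image

def prepare_for_fid(img: Image.Image) -> Image.Image:
    img = img.resize((256, 256), Image.BICUBIC)  # match the reference pipeline
    img = img.resize((299, 299), Image.BICUBIC)  # InceptionV3 input size
    return img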

xiao2mo commented 5 months ago

Hi, what's going on? Same results here, far from the paper's results.

WANGSSSSSSS commented 5 months ago

The SDE still doesn't align with the reported metrics, but the ODE does.