Closed · WANGSSSSSSS closed this 6 months ago
I use the pretrained SiT-XL. Euler ODE with cfg=1.5, w=sigma, and steps=250 achieves FID 2.45. Euler SDE with cfg=1.5, w=sigma, steps=250, and last step 0.004 achieves 2.53 (with the noise term scaled by sqrt(w*dt)). As the computation is really slow, I use fp16-mixed as the default setting.
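For reference, cfg=1.5 here means standard classifier-free guidance on the model prediction; a minimal sketch (the name `cfg_combine` is just for illustration, not from the actual code):

```python
def cfg_combine(v_cond, v_uncond, guidance=1.5):
    # standard classifier-free guidance: push the conditional prediction
    # away from the unconditional one by the guidance scale
    return v_uncond + guidance * (v_cond - v_uncond)
```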
Hi,
FID is well known to be sensitive to many small implementation choices, and using fp16 might be one of them, since the metric reported in our paper is produced with fp32.
The degradation of FID at larger cfg scales is also well known: ImageNet contains many low-quality images, so computing the metric with only high-visual-quality samples creates a distribution shift relative to ImageNet. Hope this provides some insights.
Best
Thank you for your quick reply. At least I've made some progress: I found that the FID metric for ImageNet 256 is computed on 256-sized images, so an image needs to be resized to 256 first, then to 299. After this correction I get 2.149 for the ODE and 2.34 for the SDE variant (using x + v*dt + s*w*dt + torch.sqrt(2*w*dt)*torch.randn_like(x)). Sorting out this mess really takes up your time, but I really need your help to align the SDE part. Would you mind reviewing my Euler SDE solver code below and checking that it's correct? 😭
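A minimal sketch of the resize order described above (assuming bilinear interpolation; `preprocess_for_fid` is a hypothetical helper, not the evaluation suite's actual code):

```python
import torch.nn.functional as F

def preprocess_for_fid(images):
    # images: (N, 3, H, W) float tensor; first resize to the 256 evaluation
    # resolution, then to the 299x299 Inception-v3 input size
    x = F.interpolate(images, size=(256, 256), mode="bilinear", align_corners=False)
    x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
    return x
```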
Hi, what's going on? I get the same results here, far from the paper's.
The SDE sampler is still not aligned with the reported metrics, but the ODE one is.
Hello, author. The theory and writing of SiT are really charming. I'm interested in this work and want to build some improvements on SiT, but I'm struggling to reproduce the metrics reported in the SiT paper. I've noticed some techniques that matter for getting a better FID, but they're not enough to match 2.15 and 2.09.
```python
import torch

def ode_step_fn(x, v, dt, s, w):
    # deterministic Euler step on the probability-flow ODE
    return x + v * dt

def sde_mean_step_fn(x, v, dt, s, w):
    # drift-only step (mean of the SDE update, no noise injected)
    return x + v * dt + s * w * dt

def sde_step_fn(x, v, dt, s, w):
    # Euler-Maruyama step: drift plus sqrt(2*w*dt)-scaled Gaussian noise
    return x + v * dt + s * w * dt + torch.sqrt(2 * w * dt) * torch.randn_like(x)

def sde_preserve_step_fn(x, v, dt, s, w):
    # variance-preserving variant: half the score drift, sqrt(w*dt) noise
    return x + v * dt + 0.5 * s * w * dt + torch.sqrt(w * dt) * torch.randn_like(x)
```
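To make the interface concrete, one SDE step with these helpers might look like the following sketch (`velocity_fn`, `score_fn`, and `sigma_fn` are placeholder callables, not names from the actual code):

```python
def sample_step(x, t, dt, velocity_fn, score_fn, sigma_fn):
    v = velocity_fn(x, t)  # model's velocity prediction at time t
    s = score_fn(x, t)     # model's score prediction at time t
    w = sigma_fn(t)        # diffusion coefficient, w = sigma(t) as above
    return sde_step_fn(x, v, dt, s, w)
```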
```python
import logging

logger = logging.getLogger(__name__)
```
```python
class FlowMatchEulerSampler(BaseSampler):
    def __init__(
        self,
        num_steps: int = 250,
        guidance=4.0,
        scheduler: BaseScheduler = None,
        w_scheduler: BaseScheduler = None,
        guidance_fn: Callable = c3_guidance_fn,
        step_fn: Callable = ode_step_fn,
        last_step=0.0,
        last_step_fn: Callable = ode_step_fn,
        pred_eps=False,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.scheduler = scheduler
        self.num_steps = num_steps
        self.pred_eps = pred_eps
        self.guidance = guidance
        self.guidance_fn = guidance_fn
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
```
```python
class FlowMatchHenuSampler(BaseSampler):
    def __init__(
        self,
        num_steps: int = 250,
        guidance=4.0,
        scheduler: BaseScheduler = None,
        w_scheduler: BaseScheduler = None,
        guidance_fn: Callable = c3_guidance_fn,
        pred_eps=False,
        exact_henu=False,
        step_fn: Callable = ode_step_fn,
        last_step=0.04,
        last_step_fn: Callable = ode_step_fn,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.scheduler = scheduler
        self.num_steps = num_steps
        self.pred_eps = pred_eps
        self.guidance = guidance
        self.guidance_fn = guidance_fn
        self.exact_henu = exact_henu
        self.step_fn = step_fn
        self.last_step = last_step
        self.last_step_fn = last_step_fn
        self.w_scheduler = w_scheduler
```
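For context, a sampling loop consistent with these constructor fields might look like the following sketch (the loop body is assumed from the fields above, not taken from the actual sample() implementation; it assumes time runs from noise at t=0 toward data at t=1, `w_scheduler` is callable at time t, and the final step of size `last_step` uses `last_step_fn`, often the ODE step):

```python
def euler_sample(sampler, x, velocity_fn, score_fn):
    # integrate from t=0 to t=1-last_step with step_fn, then take one
    # final deterministic step of size last_step with last_step_fn
    dt = (1.0 - sampler.last_step) / sampler.num_steps
    t = 0.0
    for _ in range(sampler.num_steps):
        v = velocity_fn(x, t)
        s = score_fn(x, t)
        w = sampler.w_scheduler(t) if sampler.w_scheduler is not None else 0.0
        x = sampler.step_fn(x, v, dt, s, w)
        t += dt
    if sampler.last_step > 0:
        v = velocity_fn(x, t)
        s = score_fn(x, t)
        w = sampler.w_scheduler(t) if sampler.w_scheduler is not None else 0.0
        x = sampler.last_step_fn(x, v, sampler.last_step, s, w)
    return x
```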