CompVis / zigma

A PyTorch implementation of the paper "ZigMa: A DiT-Style Mamba-based Diffusion Model" (ECCV 2024)
https://taohu.me/zigma
Apache License 2.0

Is training stable using mamba? #3

Closed lqniunjunlper closed 7 months ago

lqniunjunlper commented 8 months ago

Mamba sometimes has training problems with NaN loss when using fp16 or AMP.

So I am curious how to keep the training process stable while using fp16 for efficiency.

Thanks.
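
For reference, the standard mitigation in plain PyTorch is to run the forward pass under autocast and scale the loss with a GradScaler, which skips optimizer steps whose gradients contain inf/NaN. A minimal sketch (the model, loss, and data below are placeholders, not ZigMa's actual training code):

```python
import torch

model = torch.nn.Linear(128, 128).cuda()            # stand-in for any Mamba-based model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()                # rescales the loss against fp16 under/overflow
data_loader = torch.utils.data.DataLoader(torch.randn(160, 128), batch_size=16)

for x in data_loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = model(x.cuda()).pow(2).mean()        # placeholder loss
    scaler.scale(loss).backward()                   # backprop on the scaled loss
    scaler.step(optimizer)                          # unscales grads; skips the step on inf/NaN
    scaler.update()                                 # adapts the scale factor over time
```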

dongzhuoyao commented 8 months ago

Hi, in my experience it's quite stable, even on a single GPU (you need to enable the accelerator's loss rescaling to avoid the NaN loss issue).
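
This presumably refers to Hugging Face Accelerate's fp16 mixed-precision mode, which routes backward() through a GradScaler under the hood. A minimal sketch, again with a placeholder model and data rather than the repo's actual setup:

```python
import torch
from accelerate import Accelerator

# mixed_precision="fp16" makes Accelerate apply loss scaling during backward(),
# which is what avoids the NaN loss issue mentioned above.
accelerator = Accelerator(mixed_precision="fp16")

model = torch.nn.Linear(128, 128)                   # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
data_loader = torch.utils.data.DataLoader(torch.randn(160, 128), batch_size=16)
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

for x in data_loader:
    optimizer.zero_grad(set_to_none=True)
    loss = model(x).pow(2).mean()                   # placeholder loss
    accelerator.backward(loss)                      # scaled backward pass under fp16
    optimizer.step()
```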