tianweiy / DMD2

(NeurIPS 2024 Oral 🔥) Improved Distribution Matching Distillation for Fast Image Synthesis

OOM error when training the SDXL 4-step model on a single A100 #19

Closed bengen-y closed 6 months ago

bengen-y commented 6 months ago

It seems that the fake UNet occupies too much GPU memory (>13 GB) with batch_size=1.

tianweiy commented 6 months ago

Thank you for the interest. In our experiments we need at least two GPUs so that FSDP can shard the parameters across them and reduce per-GPU memory usage.

Training on 1 GPU might be possible with some further gradient checkpointing plus offloading the real UNet. I can try it out sometime tomorrow, but it might still be hard.
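
To illustrate why sharding matters here, below is a rough back-of-the-envelope sketch in Python. It is not taken from the DMD2 code. It assumes about 2.6 billion parameters for an SDXL UNet, fp32 storage, and an Adam-style optimizer that keeps two extra values per parameter; the repository's actual precision and optimizer settings may differ. The helper `training_state_gib` is a hypothetical name, and activation memory is ignored entirely.

```python
GIB = 1024 ** 3

# Assumption: an SDXL UNet has roughly 2.6 billion parameters.
SDXL_UNET_PARAMS = 2.6e9


def training_state_gib(num_params, num_gpus=1, bytes_per_value=4):
    """Rough per-GPU memory (GiB) for one trainable network.

    Counts four values per parameter: the weight, its gradient, and two
    Adam moment estimates. Assumes all four are sharded evenly across
    num_gpus. Activations and the frozen real UNet are NOT included.
    """
    values_per_param = 4  # weight + gradient + 2 Adam moments
    total_bytes = num_params * bytes_per_value * values_per_param
    return total_bytes / num_gpus / GIB


if __name__ == "__main__":
    for gpus in (1, 2, 8):
        per_gpu = training_state_gib(SDXL_UNET_PARAMS, num_gpus=gpus)
        print(f"{gpus} GPU(s): ~{per_gpu:.1f} GiB per GPU for one trainable UNet")
```

Under these assumptions, the training state of a single trainable UNet already approaches the capacity of a 40 GB card before counting activations or the other networks, which is consistent with the advice above to shard across at least two GPUs.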

bengen-y commented 6 months ago

Thank you, good job!
