```python
from torch import autograd

# create_graph=True retains the backward graph so reg_loss can be backpropagated.
grad, = autograd.grad(
    outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
)
```
This gradient computation is memory-inefficient and is not supported by flash-attention. Consequently, the batch_size has to be reduced when training with the reg_loss.
https://github.com/SHI-Labs/Smooth-Diffusion/blob/5522761bb68fcb6ac1cfaee5a5b855d4a56ea33f/train_smooth_diffusion.py#L312
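For context, here is a minimal sketch of the smoothness-regularization pattern the linked line implements (the names `model`, `latents`, and `noise_scale` are placeholders for illustration, not the variables used in `train_smooth_diffusion.py`): the output is perturbed with noise, differentiated with respect to the latents with `create_graph=True`, and the gradient norm is penalized. Keeping the backward graph alive for this second differentiation is what drives the extra memory use.

```python
import torch
from torch import autograd

def reg_loss_sketch(model, latents, noise_scale=1e-3):
    # The gradient is taken w.r.t. the latents, so they must track gradients.
    latents = latents.detach().requires_grad_(True)
    fake_img = model(latents)

    # (fake_img * noise).sum() is a scalar; its gradient w.r.t. latents is a
    # noise-weighted vector-Jacobian product through the whole model.
    noise = torch.randn_like(fake_img) * noise_scale

    # create_graph=True keeps the full backward graph so reg_loss can itself be
    # backpropagated (double backward), which roughly doubles activation memory
    # and requires a second-order backward through the attention layers.
    grad, = autograd.grad(
        outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True
    )

    # Penalize the per-sample gradient norm.
    return grad.pow(2).flatten(1).mean(dim=1).mean()
```

Besides lowering batch_size, one possible workaround (an assumption about the setup, not something the script already does) is to force the math SDPA backend just for this forward pass, since flash-attention kernels generally do not implement the double backward that `create_graph=True` requires, e.g. with `torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True)` in PyTorch 2.x.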