lucidrains / recurrent-interface-network-pytorch

Implementation of Recurrent Interface Network (RIN), for highly efficient generation of images and video without cascading networks, in PyTorch
MIT License
194 stars, 14 forks

Min SNR weighting implemented as max SNR weighting #11

Closed: justinlovelace closed this 1 year ago

justinlovelace commented 1 year ago

Hey Phil! I believe that the Min SNR weighting is currently implemented as

loss_weight = max(snr, self.min_snr_gamma)

instead of loss_weight = min(snr, self.min_snr_gamma)

I think line 839 should be maybe_clipped_snr.clamp_(max = self.min_snr_gamma).

https://github.com/lucidrains/recurrent-interface-network-pytorch/blob/77e4ced5cd8d091acae31ad2cb19a32f64a4eb3a/rin_pytorch/rin_pytorch.py#L833-L850
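To make the difference concrete, here is a minimal pure-Python sketch, not the repository's code. It assumes an x0-prediction objective, where the Min-SNR-γ weight is simply min(SNR, γ), and uses γ = 5 only as an example value; the function names min_snr_weight and buggy_weight are hypothetical. In PyTorch the corrected behavior corresponds to snr.clamp(max = gamma), while clamping with min = gamma yields max(snr, gamma), which is the behavior reported above.

```python
def min_snr_weight(snr: float, gamma: float = 5.0) -> float:
    """Intended Min-SNR-gamma weight (x0-prediction assumed): cap the SNR at gamma."""
    return min(snr, gamma)


def buggy_weight(snr: float, gamma: float = 5.0) -> float:
    """Behavior described in the issue: clamping from below gives max(SNR, gamma)."""
    return max(snr, gamma)


# Compare the two weightings across noisy (low SNR) and clean (high SNR) timesteps.
for snr in [0.1, 1.0, 5.0, 50.0, 1000.0]:
    print(f"snr={snr:>7}  min-SNR weight={min_snr_weight(snr):>7}  buggy weight={buggy_weight(snr):>7}")
```

With the buggy version, very noisy timesteps are boosted up to γ while nearly clean timesteps keep their full, unbounded SNR weight, which is the opposite of the intent: Min-SNR weighting is meant to cap the weight of the easy, low-noise timesteps.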

Apologies if I'm misunderstanding something 🙂

lucidrains commented 1 year ago

@justinlovelace oh yes you are right! thank you for finding this so quickly

ugh i committed this error in a number of other repos too

justinlovelace commented 1 year ago

Happy to help! I was just trying the new weighting scheme out for some of my work and noticed the issue.

lucidrains commented 1 year ago

thank you Justin! and if Lovelace is your last name, you'd be the first Lovelace I've met (besides knowing the famous Ada)

justinlovelace commented 1 year ago

It is! I suppose it's a fitting name to have in this field 🙂