Confusezius closed this 2 years ago
Oh goodness, yes. I don’t know how I let this regression slip through.
Would you like to submit a PR with the change? I’ll happily merge it.
Thank you for the kind words and submitting this issue.
Fixed: https://github.com/nmichlo/disent/releases/tag/v0.5.1 (I have also released v0.6.0 with MPI3D fixes.)
Found an additional error, in both the minimal version and the function above it.
It is the same error as the one you found: https://github.com/nmichlo/disent/blob/d3aa2d30194735d13fb700ad9844e60d742ab914/disent/frameworks/vae/_weaklysupervised__adavae.py#L221
It should be torch.where(share_mask, <ave_posterior>, <orig_posterior>), so that the shared dimensions take the averaged posterior and the remaining dimensions keep the original one:
https://github.com/nmichlo/disent/blob/d3aa2d30194735d13fb700ad9844e60d742ab914/disent/frameworks/vae/_weaklysupervised__adavae.py#L326-L329
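Since torch.where(cond, a, b) takes a where cond is True and b elsewhere, the argument order decides which posterior ends up in the shared dimensions. Below is a minimal sketch of the corrected order, assuming the posteriors are torch.distributions.Normal objects and share_mask is a boolean tensor over latent dimensions. The helper name merge_shared and the toy values are hypothetical, not disent's code.

```python
import torch
from torch.distributions import Normal


def merge_shared(share_mask: torch.Tensor, ave_posterior: Normal, orig_posterior: Normal) -> Normal:
    """Hypothetical helper: overwrite the shared dims of orig_posterior with ave_posterior.

    torch.where(cond, a, b) picks a where cond is True and b elsewhere, so the
    averaged posterior must come second for the shared dims to be averaged.
    """
    loc = torch.where(share_mask, ave_posterior.mean, orig_posterior.mean)
    scale = torch.where(share_mask, ave_posterior.stddev, orig_posterior.stddev)
    return Normal(loc=loc, scale=scale)


# Toy values (hypothetical) for a 3-dim latent where dims 0 and 2 are shared.
orig = Normal(loc=torch.tensor([0.0, 1.0, 2.0]), scale=torch.tensor([1.0, 1.0, 1.0]))
ave = Normal(loc=torch.tensor([5.0, 5.0, 5.0]), scale=torch.tensor([2.0, 2.0, 2.0]))
share_mask = torch.tensor([True, False, True])

merged = merge_shared(share_mask, ave, orig)
print(merged.mean)  # tensor([5., 1., 5.])

# With the arguments swapped, the shared dims keep their original values
# and only the non-shared dim is overwritten by the average.
swapped_mean = torch.where(share_mask, orig.mean, ave.mean)
print(swapped_mean)  # tensor([0., 5., 2.])
```

The swapped call silently produces the inverse behaviour (non-shared dims get averaged), which would not raise any error during training.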
Great work with this repository!
A quick note: In the GVAE averaging operation, I assume that
https://github.com/nmichlo/disent/blob/d3aa2d30194735d13fb700ad9844e60d742ab914/disent/frameworks/vae/_weaklysupervised__adavae.py#L238
is supposed to be
right?
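For context, GVAE-style averaging is commonly described (Locatello et al., 2020, "Weakly-Supervised Disentanglement Without Compromises") as the arithmetic mean of the two encoder posteriors, taken over their means and over their variances. The sketch below assumes that definition and Normal posteriors. It is not the disent code at L238 and not a statement of what that line should read. The function name gvae_average and the toy values are hypothetical.

```python
import torch
from torch.distributions import Normal


def gvae_average(d0_posterior: Normal, d1_posterior: Normal) -> Normal:
    # Assumed GVAE averaging: arithmetic mean of the means, and arithmetic
    # mean of the variances (not of the standard deviations).
    ave_mean = 0.5 * (d0_posterior.mean + d1_posterior.mean)
    ave_var = 0.5 * (d0_posterior.variance + d1_posterior.variance)
    return Normal(loc=ave_mean, scale=torch.sqrt(ave_var))


# Toy posteriors (hypothetical values) for a 2-dim latent.
d0 = Normal(loc=torch.tensor([0.0, 2.0]), scale=torch.tensor([1.0, 1.0]))
d1 = Normal(loc=torch.tensor([2.0, 4.0]), scale=torch.tensor([3.0, 1.0]))

ave = gvae_average(d0, d1)
print(ave.mean)  # tensor([1., 3.])
print(ave.variance)
```

Averaging variances rather than standard deviations matters: for the first dimension the averaged variance is 0.5 * (1 + 9) = 5, whereas averaging the standard deviations would give 0.5 * (1 + 3) = 2, i.e. a variance of 4.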