AIH-SGML / mixmil

Code for the paper: Mixed Models with Multiple Instance Learning
https://arxiv.org/abs/2311.02455
Apache License 2.0

Bug in KL scaling #8

Closed · stephandooper closed this issue 5 months ago

stephandooper commented 5 months ago

Hi, sorry for bothering you again, but I think I found a small bug in the KL divergence term of the loss function.

Currently, the KL loss in the code is already weighted by the batch size divided by the total dataset size (kld_w). So far so good.

However, the additional division by y.shape[0] is a division by the batch size, at least when I ran the Camelyon16 example, so the KL loss effectively ends up divided by the full dataset size, which makes it very small.

Potential fix: I guess y.shape[0] was meant to be the number of outputs P? In that case, I think it is a simple fix, changing the line to:

kld_term = kld_w * kld.sum() / y.shape[1]

or

kld_term = kld_w * kld.sum() / self.P

to avoid any confusion about the dimensions.

https://github.com/AIH-SGML/mixmil/blob/bae25eba1d2ece9d30df5d4c79e1676ba1989f19/mixmil/model.py#L103
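
To make the concern concrete, here is a minimal numeric sketch of the two scalings (toy numbers and names are illustrative, not the actual mixmil code):

import torch

# Toy values, purely illustrative.
n_train = 1000       # training-set size
batch_size = 64      # mini-batch size -> y.shape[0]
n_outputs = 1        # number of outputs P -> y.shape[1]

kld = torch.tensor([2.5])        # stand-in for the summed KL term
kld_w = batch_size / n_train     # the weight the training loop passes in

current = kld_w * kld.sum() / batch_size   # == kld.sum() / n_train -> 0.0025
proposed = kld_w * kld.sum() / n_outputs   # == 0.16
print(current.item(), proposed.item())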

jan-engelmann commented 5 months ago

Hi again :D

That scaling is correct. The important thing is that the relative scale of the LL and KL terms is the same as if you were training full-batch. Since we take the mean across the batch for the LL term, the LL is on the scale of a single observation. Therefore, as you point out, we divide the KL by the number of samples in the dataset.
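
For concreteness, a minimal numeric sketch of this bookkeeping (names and numbers are illustrative, not the actual mixmil training loop):

import torch

N, B = 1000, 64                 # dataset size and mini-batch size
kl = torch.tensor(3.0)          # stand-in for KL(q || p), constant w.r.t. the batch
ll_per_sample = torch.randn(N)  # stand-in per-sample log-likelihoods

# Full-batch objective, rescaled to the per-observation level:
full = ll_per_sample.mean() - kl / N

# Mini-batch objective as implemented: mean LL over the batch,
# plus kld_w * kl / batch_size with kld_w = B / N.
idx = torch.randperm(N)[:B]
kld_w = B / N
mini = ll_per_sample[idx].mean() - kld_w * kl / B   # KL term == kl / N, same as full-batch

# The KL contribution is identical in both cases; the batch LL mean is an
# unbiased estimate of the full-batch per-observation LL.
print((kl / N).item(), (kld_w * kl / B).item())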

You can find a different but equivalent formulation of the mini-batched ELBO in this paper (eq 3).

Cheers, Jan