AIH-SGML / mixmil

Code for the paper: Mixed Models with Multiple Instance Learning
https://arxiv.org/abs/2311.02455
Apache License 2.0

How to train MixMIL with categorical data? #13

Closed: patricks-lab closed this issue 3 weeks ago.

patricks-lab commented 1 month ago

Thanks for the great work (and congrats on being selected as an oral at this year's AISTATS)!

I'm trying to train a MixMIL model on categorical data.

Since I had trouble getting it to work on my own data, I tried the "mock" test data instead. I adapted the code from the mock_data_categorical function in tests (which only initializes a mean model but doesn't train it):

https://github.com/AIH-SGML/mixmil/blob/74b0940e2bee39f0ca2de1d46111608beaa756c3/tests/test_model.py#L17-L23

and wrote my own mock data training script:

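```python
from mixmil import MixMIL
import torch

device = "cuda:1" if torch.cuda.is_available() else "cpu"

# Mock data adapted from mock_data_categorical in tests/test_model.py
N, Q, K = 50, 10, 4                     # bags, instance feature dim, fixed-effect covariates
bag_sizes = torch.randint(5, 15, (N,))  # number of instances in each bag
Xs = [torch.randn(bag_sizes[n], Q) for n in range(N)]  # list of per-bag instance tensors
F = torch.randn(N, K)                   # fixed effects
Y = torch.randint(0, 5, (N, 1))         # categorical labels, 5 classes (0-4)

model = MixMIL.init_with_mean_model(Xs, F, Y, likelihood="categorical")
model.train(Xs, F, Y, n_epochs=10)      # this is the call that fails
```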

However, even with the mock data, model.train() fails with a runtime error: [screenshot: error traceback]

When I print the shapes of the offending scale_u and scale_z, scale_u has shape [5, 10] and scale_z has shape [1, 10]. So torch.cat over dim=1 fails, because the tensors differ in dim 0.
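A minimal sketch of that failure, using plain tensors with the shapes printed above (the values are placeholder ones, and this is not MixMIL's internal code). torch.cat requires every dimension except the concatenation dimension to match, so a leading size of 5 versus 1 is rejected:

```python
import torch

# Placeholder tensors with the shapes reported in this thread
scale_u = torch.ones(5, 10)
scale_z = torch.ones(1, 10)

try:
    torch.cat([scale_u, scale_z], dim=1)
except RuntimeError as err:
    # dim 0 differs (5 vs 1), so concatenating along dim 1 is not allowed
    print("torch.cat failed:", err)

# A hypothetical tensor whose leading dimension matches concatenates fine
scale_z_matching = torch.ones(5, 10)
combined = torch.cat([scale_u, scale_z_matching], dim=1)
print(combined.shape)  # torch.Size([5, 20])
```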

Also, printing the model gives the following: [screenshot: printed model summary]

I would greatly appreciate any tips on getting MixMIL to train on categorical data (e.g. fixes for the above scenario or sample code).

Thanks in advance! Patrick

stephandooper commented 1 month ago

Hi,

I ran into the same problem. I think it comes from log_sigma_z having the wrong shape: it is (1, 1) in your model print, while I think it should be (1, 5), one entry per class.

Looking into it a bit further, here is where it goes wrong. The first place to look is the instantiation of MixMIL itself. This isn't exactly a bug, but here the variable var_z is overwritten by the contents of init_params:

https://github.com/AIH-SGML/mixmil/blob/74b0940e2bee39f0ca2de1d46111608beaa756c3/mixmil/model.py#L46

The init_params themselves are computed in utils.py, which has one function for the binomial and one for the categorical likelihood. Zooming in on the categorical one, we find the following code:

https://github.com/AIH-SGML/mixmil/blob/74b0940e2bee39f0ca2de1d46111608beaa756c3/mixmil/utils.py#L79

If you run this on your example, you will see that mu_beta and sd_beta both have shape [10, 5]. However, the code then takes the mean over both dimensions, which collapses the class dimension into a single value; I think that is wrong here.
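To see why the reduction matters, here is a small sketch with random NumPy stand-ins of the [10, 5] shape reported above (10 instance features, 5 classes). It assumes, as described in this comment, that the original code averages mu_beta**2 + sd_beta**2 over all entries; the arrays are illustrative only and not what utils.py actually produces:

```python
import numpy as np

# Hypothetical stand-ins for the mean-model estimates: Q features x C classes
Q, C = 10, 5
rng = np.random.default_rng(0)
mu_beta = rng.normal(size=(Q, C))          # estimated means
sd_beta = np.abs(rng.normal(size=(Q, C)))  # estimated standard deviations

# Second moment of each coefficient: mean^2 + variance, shape (Q, C)
second_moment = mu_beta**2 + sd_beta**2

# Averaging over both axes leaves one scalar for all classes
var_z_all = second_moment.mean().reshape(1, 1)

# Averaging over features only keeps one value per class
num_classes = mu_beta.shape[1]
var_z_per_class = second_moment.mean(axis=0).reshape(1, num_classes)

print(var_z_all.shape)        # (1, 1)
print(var_z_per_class.shape)  # (1, 5)
```

The (1, 1) result matches the log_sigma_z shape seen in the model print, while the per-class version gives the (1, 5) shape one would expect for five classes.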

Possible solutions

I think there are two solutions you can use here:

  1. Don't use init_with_mean_model at all; instead, start from random weights using the regular init function, or
  2. Adjust the computation of the mean to take the class dimension into account, like so:
     ```python
     num_classes = mu_beta.shape[1]
     var_z = (mu_beta**2 + sd_beta**2).mean(axis=0).reshape(1, num_classes)
     ```

Full disclosure: I don't know whether the second solution is correct.

jan-engelmann commented 1 month ago

Hi you two,

Thanks a lot for pointing this out! @stephandooper's solution is correct. I have updated the code to fix the bug; see PR #15.

Could you run pip install mixmil --upgrade and check whether v0.1.2 fixes the issue for you?
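After upgrading, a quick way to confirm which release is actually installed is the standard-library importlib.metadata module. This is a minimal sketch; the helper names below are hypothetical, and the simple parser assumes plain numeric versions like "0.1.2" (pre-release tags are not handled):

```python
from importlib.metadata import PackageNotFoundError, version

FIXED_VERSION = "0.1.2"  # release mentioned in this thread as containing the fix


def parse_version(v: str) -> tuple:
    """Turn a plain dotted release string like '0.1.2' into (0, 1, 2)."""
    return tuple(int(part) for part in v.split("."))


def has_categorical_fix(installed: str, fixed: str = FIXED_VERSION) -> bool:
    """True if the installed release is at least the fixed release."""
    return parse_version(installed) >= parse_version(fixed)


if __name__ == "__main__":
    try:
        installed = version("mixmil")
    except PackageNotFoundError:
        print("mixmil is not installed")
    else:
        status = "includes" if has_categorical_fix(installed) else "predates"
        print(f"mixmil {installed} {status} the v{FIXED_VERSION} fix")
```

Comparing integer tuples rather than raw strings avoids the trap where "0.1.10" would sort before "0.1.2" as text.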

Thanks again!

Cheers, Jan

jan-engelmann commented 3 weeks ago

Feel free to re-open the issue if it persists!