royerlab / cytoself

Self-supervised models for encoding protein localization patterns from microscopy images
BSD 3-Clause "New" or "Revised" License

UMap not looking good #27

Closed sofroniewn closed 1 year ago

sofroniewn commented 1 year ago

Hi @li-li-github - I was away for a while, but I'm getting back to this now. I closed my other issues to keep things clean, but thanks for your help there!

I've been trying to retrain cytoself on the full data, and I can get to pretty good-looking images, but when I run umap I don't see any structure. For example:

[image]

and

[image]

I've trained for ~17 epochs and get to a reconstruction loss of around 0.05 for reconstruction1_loss.

I'm using my own trainer class and might have slightly different hyperparameters, but things look much worse than I would expect if I had things implemented right. Do you have target numbers for the other training losses, or any other ideas on what might be going wrong here? If it's easier to hop on a Zoom call, I can talk some time too.

Thanks!!!

sofroniewn commented 1 year ago

Some additional info: fc2_loss was around 0.72 and still going down, so I may have stopped training too soon. I'm doing a longer run now. My vq2_softmax_loss also started going up - I'm not sure if that is bad.

I'm also using coeffs of 1.0 for fc and vq. If I increase the fc coeff, I guess I should get better at prediction - do you think that would mean a better umap? I still feel something more fundamental might be wrong, though, since the umap is so homogeneous.

li-li-github commented 1 year ago

Hi, it looks like you're using the wrong latent vectors for the UMAP. How did you get the UMAP?

You can try getting the latent vectors first with latent_vec = trainer.infer_embeddings(<test_data>, 'vqvec2') and then running UMAP directly with the umap-learn package to see if you can get a correct UMAP.
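Something like this, as a minimal sketch - it assumes trainer is a fitted CytoselfFullTrainer, test_data is your held-out image array, and that infer_embeddings returns a single array shaped (N, C, H, W); the flattening and plotting details here are just illustrative, not a fixed API.

import umap
import matplotlib.pyplot as plt

# Get the global (vqvec2) latent vectors and flatten each one into a single row.
latent_vec = trainer.infer_embeddings(test_data, 'vqvec2')
flat = latent_vec.reshape(latent_vec.shape[0], -1)

# Run UMAP from the umap-learn package directly on the flattened vectors.
reducer = umap.UMAP(n_neighbors=15, min_dist=0.1)
embedding_2d = reducer.fit_transform(flat)

plt.scatter(embedding_2d[:, 0], embedding_2d[:, 1], s=1)
plt.show()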

The overall training looks good according to the reconstructed images. If you still can't get the right UMAP, we can also have a Zoom call. Feel free to email me to schedule one.

sofroniewn commented 1 year ago

I get roughly the same umap from analysis.plot_umap_of_embedding_vector, trainer.infer_embeddings, or using trainer.model directly, all with 'vqvec2', and whether I use your trainer or my own at a point where the reconstructions look good. If I'm just using the wrong latents, that would be great, but I'm not sure what else to try next. A Zoom would be great - I can't find your email, so I'll just send you a DM on Twitter and we can schedule there. Thanks!!

sofroniewn commented 1 year ago

To train the model, one approach I'm using is just your trainer with these parameters:

model_args = {
    'input_shape': (2, 100, 100),
    'emb_shapes': ((25, 25), (4, 4)),
    'output_shape': (2, 100, 100),
    'fc_output_idx': [2],
    'vq_args': {'num_embeddings': 2048, 'embedding_dim': 64},
    'fc_args': {'num_layers': 2},
    'num_class': num_classes,  # number of protein classes in my dataset
    'fc_input_type': 'vqvec',
}
train_args = {
    'lr': 1e-3,
    'max_epoch': 30,
    'reducelr_patience': 3,
    'reducelr_increment': 0.1,
    'earlystop_patience': 6,
}
trainer = CytoselfFullTrainer(train_args, homepath='demo_output3', model_args=model_args)
trainer.fit(pdm, tensorboard_path='tb_logs')  # pdm is my data manager

I ran the above for 10 epochs and got the following performance curves:

[image: performance curves]

This level of reconstruction:

[image: reconstructions]

and this umap:

[image: UMAP]

(note: the blue is the "other" category)

sofroniewn commented 1 year ago

Here is some of the trainer history

train_loss train_fc2_loss train_perplexity1 train_perplexity2 train_reconstruction1_loss train_reconstruction2_loss train_vq1_loss train_vq1_commitment_loss train_vq1_quantization_loss train_vq1_softmax_loss ... val_reconstruction2_loss val_vq1_loss val_vq1_commitment_loss val_vq1_quantization_loss val_vq1_softmax_loss val_vq2_loss val_vq2_commitment_loss val_vq2_quantization_loss val_vq2_softmax_loss lr
5.856858 5.403948 932.426041 338.800671 0.190714 0.062433 0.153326 0.122660 0.122660 3.639950 ... 0.036588 0.076991 0.061593 0.061593 4.007305 0.053910 0.043128 0.043128 4.942319 0.0010
3.863978 3.719815 791.804969 503.562573 0.071442 0.023739 0.012655 0.010124 0.010124 4.639841 ... 0.021032 0.009363 0.007491 0.007491 4.949356 0.049799 0.039839 0.039839 2.312072 0.0010
3.190841 3.056592 987.251616 503.740017 0.054630 0.019380 0.009005 0.007204 0.007204 5.067519 ... 0.019367 0.008913 0.007130 0.007130 5.195401 0.055444 0.044355 0.044355 1.993897 0.0010
2.888193 2.758002 1105.496906 493.675083 0.048800 0.018582 0.008659 0.006928 0.006928 5.218154 ... 0.017192 0.007986 0.006389 0.006389 5.331193 0.060996 0.048797 0.048797 1.786524 0.0010
2.717702 2.588516 1167.443718 490.424035 0.046288 0.018400 0.008744 0.006995 0.006995 5.276464 ... 0.018189 0.008493 0.006794 0.006794 5.355626 0.063466 0.050772 0.050772 1.790967 0.0010
2.607975 2.477561 1192.986759 488.258119 0.044662 0.018399 0.008899 0.007119 0.007119 5.300541 ... 0.018378 0.009024 0.007219 0.007219 5.336828 0.071301 0.057040 0.057040 1.711513 0.0010
2.533543 2.403831 1213.297743 485.345133 0.043448 0.018335 0.008937 0.007149 0.007149 5.318454 ... 0.018942 0.008917 0.007134 0.007134 5.316447 0.071536 0.057229 0.057229 1.820205 0.0010
2.082423 1.952750 1253.019133 469.576832 0.037060 0.018221 0.009005 0.007204 0.007204 5.337982 ... 0.017593 0.008567 0.006854 0.006854 5.391302 0.084770 0.067816 0.067816 1.317909 0.0001
2.000487 1.850641 1257.308343 464.171818 0.035759 0.017754 0.008867 0.007094 0.007094 5.387692 ... 0.016672 0.008248 0.006598 0.006598 5.444560 0.107130 0.085704 0.085704 0.897221 0.0001
1.977823 1.807582 1258.851098 462.197427 0.035208 0.017314 0.008660 0.006928 0.006928 5.427403 ... 0.016944 0.008341 0.006673 0.006673 5.465038 0.130319 0.104255 0.104255 0.766412 0.0001
sofroniewn commented 1 year ago

One thing I just noticed is that my embeddings have shape (N, 64, 4, 4), but when I read your paper I see this:

our model encodes two representations for each image that correspond to two different spatial scales, the local and global representations, that correspond to VQ1 and VQ2, respectively. The global representation captures large-scale image structure scaled-down to a 4 × 4 pixel image with 576 features (values) per pixel. The local representation captures finer spatially resolved details (25 × 25 pixel image with 64 features per pixel).

Should I have embeddings that are (N, 576, 4, 4) somewhere? I don't know where the 576 number comes from.

sofroniewn commented 1 year ago

If I use channel_split=9 in the vq_args then my embeddings have shape (N, 576, 4, 4), but I fill up my GPU much more easily, so I'll need to use a smaller batch size and training will be slower. Do you think this is important? What is the channel_split parameter? I couldn't find much info about it.
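For what it's worth, the 576 seems to line up with a simple product; a tiny check, assuming channel_split just multiplies the effective feature count per pixel (my reading of the shapes above, not confirmed documentation):

embedding_dim = 64    # from vq_args above
channel_split = 9     # the value that produced (N, 576, 4, 4) embeddings
print(embedding_dim * channel_split)   # 576, the paper's "576 features per pixel"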

sofroniewn commented 1 year ago

One other thing: right now the images in my batch have shape (B, 2, 100, 100), where channel 0 is the protein and channel 1 is the nucleus stain. Should I be giving the model both protein and nucleus? Do I need to specify which channel is the protein for the fc layer? Something might be going wrong here.

sofroniewn commented 1 year ago

@li-li-github - there is definitely an issue with the dataloading strategy. I haven't quite figured out what's going on yet, but it looks like I get different umaps depending on the batch size and composition - for example, using your dataloader, if I use a batch size of 1 I get a bad umap (it looks like mine), but if I use a batch size of 32 I get a good one. Note this is the batch size at inference time, not at training. I will keep looking into this, but any ideas why the embedding results might be affected by the batch?

sofroniewn commented 1 year ago

[image: UMAP]

Success! The ONLY difference in my code between generating this umap and the bad one above is shuffle=True in the dataloader for my test set at inference time. I had shuffle=False for my test set by default, and after a lot of searching I narrowed the issue down to this parameter.

I'm guessing there is something about the VQ-VAE architecture where the exact composition of the batch determines which embedding vector each image gets mapped to, and when all the images in a batch were of the same protein (since there was no shuffling), things didn't map as expected. This seems like a limitation and something that is quite brittle. I will look more into the code and papers to understand what's going on here. I wonder if modifications could be made so that the model is no longer sensitive to the batch composition.

I'd love to get your thoughts here, @li-li-github. I'm curious whether @royerloic might have an idea too - probably good to be aware of as well.

sofroniewn commented 1 year ago

Ah, so it turns out I had also forgotten to put the model in eval mode when computing embeddings. This fixed the dependence on the shuffle (for reasons I don't quite understand). I now get the following, which looks okay, but not as good as you had in the paper. I guess now I just need to train more. I assume you put the model in eval mode when computing embeddings?

[image: UMAP]
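For anyone hitting the same thing, here's a minimal sketch of the eval-mode fix, assuming trainer.model is a regular torch.nn.Module (as used earlier in this thread) and that the embedding call doesn't switch modes for you - I haven't checked whether infer_embeddings does this internally:

import torch

# Put the model in eval mode before inference so batch norm uses its running
# statistics rather than the statistics of the current batch.
trainer.model.eval()
with torch.no_grad():
    embeddings = trainer.infer_embeddings(test_data, 'vqvec2')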

I think I'm fairly close now to having this all working right, but any additional insights, @li-li-github, would be appreciated. If it's easier, we can talk one more time. Thanks again for your help here.

sofroniewn commented 1 year ago

Ok - I think I'm in business now, getting pretty good results with more training and eval mode:

[image: UMAP]

JoOkuma commented 1 year ago

FYI, as discussed in person with @sofroniewn: the results were different with shuffle on vs. off because the model was still in training mode, so the batch norm was being updated on each batch.
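To illustrate the point with plain PyTorch (nothing cytoself-specific): in train() mode BatchNorm normalizes each sample with the statistics of its batch, so the output for a given image depends on which other images happen to share the batch; in eval() mode it uses the stored running statistics and becomes batch-independent.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)
x = torch.randn(8, 4, 10, 10)

bn.train()
full = bn(x)[0]        # sample 0 normalized with the whole batch's statistics
single = bn(x[:1])[0]  # sample 0 normalized with only its own statistics
print(torch.allclose(full, single))   # False: output depends on batch composition

bn.eval()
full = bn(x)[0]
single = bn(x[:1])[0]
print(torch.allclose(full, single))   # True: running stats, batch-independent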

sofroniewn commented 1 year ago

Thanks! All clear now