mahmoodlab / HIPT

Hierarchical Image Pyramid Transformer - CVPR 2022 (Oral)

subtyping: training loss not really decreasing #25

Closed clemsgrs closed 1 year ago

clemsgrs commented 2 years ago

Hi, I'm trying to replicate the subtyping results you report on TCGA BRCA as a sanity check before applying HIPT to a different dataset. Hence, I'm using the same slides, same splits & same labels as given in the repo. For now, I've stuck to training & evaluating on fold_0.

I'm having trouble training a model: my training loss barely goes down (see the picture below: the loss plateaus after epoch 6, with training AUC around 0.50).

[W&B chart, 09/11/2022: training loss plateauing after epoch 6]

After digging deep into the code, there are a few things I'd love your help to understand:

In HIPT_LGP_FC you set self.local_vit = vit4k_xs(); based on the following lines, this means self.local_vit is an instance of VisionTransformer4K with patch_size = 16: https://github.com/mahmoodlab/HIPT/blob/2e0adbe943175bcc13327a4f2e8785b59d6c6249/HIPT_4K/vision_transformer4k.py#L267-L272

Then, looking at the VisionTransformer4K class, the default img_size argument is [224]. Combined with patch_size = 16, this means num_patches = 196 (line 170), which is used on line 174 to instantiate self.pos_embed: https://github.com/mahmoodlab/HIPT/blob/2e0adbe943175bcc13327a4f2e8785b59d6c6249/HIPT_4K/vision_transformer4k.py#L161-L174

Hence, if we feed HIPT_LGP_FC a tensor of shape [M, 256, 384], as done in the model walkthrough notebook, the interpolate_pos_encoding method gets called at some point during the forward pass. Given x.shape = [M, 257, 384] and pos_embed.shape = [1, 197, 192], npatch = 256 and N = 196: the condition npatch == N on line 204 is False, so we need to interpolate the positional embedding: https://github.com/mahmoodlab/HIPT/blob/2e0adbe943175bcc13327a4f2e8785b59d6c6249/HIPT_4K/vision_transformer4k.py#L201-L205
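To make the shape bookkeeping concrete, here is a minimal sketch of what I think happens (the names and the exact reshape are illustrative; only the shapes matter):

```python
import torch

M = 4                                        # number of [4096, 4096] regions in one slide
x = torch.randn(M, 256, 384)                 # pre-extracted 256-level features per region

# HIPT_LGP_FC rearranges this grid into an image-like tensor before self.local_vit:
x_img = x.transpose(1, 2).reshape(M, 384, 16, 16)   # [M, 384, 16, 16]

# With the default img_size = [224] and the hard-coded patch size of 16:
num_patches = (224 // 16) ** 2               # 196 -> self.pos_embed has shape [1, 197, 192]
npatch = x_img.shape[-2] * x_img.shape[-1]   # 256 tokens after the patch embedding
print(npatch == num_patches)                 # False -> interpolate_pos_encoding gets called
```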

  1. Why is the patch_size argument passed when instantiating VisionTransformer4K not actually used in VisionTransformer4K.__init__()? Instead, a hard-coded value of 16 is used (line 170, see below): https://github.com/mahmoodlab/HIPT/blob/2e0adbe943175bcc13327a4f2e8785b59d6c6249/HIPT_4K/vision_transformer4k.py#L170

  2. Why is the img_size argument passed when instantiating VisionTransformer4K left as the default (i.e. img_size = [224]) and not set to [256]? I get that during self-supervised pre-training you use crops of size [224, 224], but during subtyping we're using the full [256, 256] patch, so I guess we should use img_size = [256], shouldn't we? Doing so, the previously discussed condition npatch == N would become True (so we would no longer need to interpolate the positional embedding).

  3. Given that we pass a tensor of shape [M, 256, 384] to HIPT_LGP_FC, which gets reshaped to [M, 384, 16, 16] before being passed to HIPT_LGP_FC.local_vit, the following line gives B = M. https://github.com/mahmoodlab/HIPT/blob/2e0adbe943175bcc13327a4f2e8785b59d6c6249/HIPT_4K/vision_transformer4k.py#L226 Then, in the following lines, cls_token is expanded to a tensor of shape [M, 1, 192]. Isn't there a confusion between B (supposed to account for the batch size) and M (the number of [4096, 4096] regions per slide)? Shouldn't the cls_token tensor be of shape [batch_size, 1, 192]? https://github.com/mahmoodlab/HIPT/blob/2e0adbe943175bcc13327a4f2e8785b59d6c6249/HIPT_4K/vision_transformer4k.py#L232-L233

  4. I've also tried training only the global aggregation layers by directly feeding the region-level pre-extracted features (of shape [M, 192]), without success (the training loss isn't really decreasing either). Could you confirm that this should work just as well as training the intermediate transformer + the global aggregation layers on the [M, 256, 384] features?

Thanks!

Richarizardd commented 2 years ago

Hi @clemsgrs - thank you for the detailed post; I will do my best to respond.

  1. Why is the patch_size argument passed when instantiating VisionTransformer4K but not used?

Apologies - it was a typo. I wanted to make the token size an argument for instantiating VisionTransformer4K, but ended up keeping the arguments of VisionTransformer4K the same as those of the regular VisionTransformer class, since there is no change in ViT sequence-length complexity. Whether it is 256-sized images with 16-sized patching or 4096-sized images with 256-sized patching, the sequence length is always 16*16 = 256. To be exact, the image size for VisionTransformer4K should technically be 3584 with a patch size of 256, since during pretraining the maximum global crop size is [14 x 14] in a [16 x 16 x 384] 2D grid of pre-extracted feature embeddings of 256-sized patches. However, since VisionTransformer4K doesn't actually take in 3584/4096-sized images but rather the 2D grid of pre-extracted feature embeddings, it was easier to keep VisionTransformer4K the same as VisionTransformer. I will fix some of these arguments so that it is less confusing.
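To spell out the arithmetic behind this:

```python
# Sequence length is identical at both scales, so the ViT complexity is unchanged:
tokens_256  = (256 // 16) ** 2      # 256-sized image, 16-sized patches    -> 256 tokens
tokens_4096 = (4096 // 256) ** 2    # 4096-sized region, 256-sized patches -> 256 tokens
assert tokens_256 == tokens_4096 == 256

# The largest global crop during pretraining spans a 14 x 14 sub-grid of the
# 16 x 16 feature grid, i.e. an effective image size of 14 * 256 = 3584.
effective_pretrain_size = 14 * 256  # 3584
```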

  2. Why is the img_size argument passed when instantiating VisionTransformer4K left as the default (i.e. img_size = [224]) and not set to [256]...?

See the above comment. In addition, I would note that, as in the original VisionTransformer from DINO, we can't set img_size = [256] when instantiating VisionTransformer4K because the model was pretrained with img_size = [224], and thus the sequence length in self.pos_embed is (224/16)**2 + 1 = 197. If you change img_size = [256], you would not be able to load the pretrained weights. Despite some of the typos, everything ended up working in a roundabout way because the ViT complexities are consistent across image resolutions, but apologies for the confusion!
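For reference, here is a rough sketch of why the shape mismatch matters when loading weights non-strictly (this mirrors the usual shape-filtered loading pattern rather than the exact loading code in the repo, and the checkpoint key handling is an assumption):

```python
import torch
from vision_transformer4k import vit4k_xs  # HIPT_4K/vision_transformer4k.py on the path

model = vit4k_xs()                          # default img_size=[224] -> pos_embed is [1, 197, 192]
ckpt = torch.load("vit4k_xs_dino.pth", map_location="cpu")
state_dict = ckpt.get("teacher", ckpt)      # checkpoint layout is an assumption here
state_dict = {k.replace("module.", "").replace("backbone.", ""): v
              for k, v in state_dict.items()}

model_sd = model.state_dict()
# Keep only tensors whose shapes match. If the model were built with img_size=[256],
# its pos_embed would be [1, 257, 192]; the checkpoint entry would be filtered out
# here and pos_embed would stay randomly initialized (and then frozen).
filtered = {k: v for k, v in state_dict.items()
            if k in model_sd and v.shape == model_sd[k].shape}
model.load_state_dict(filtered, strict=False)
```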

  3. Isn't there a confusion between B (supposed to account for the batch size) and M (number of [4096, 4096] regions per slide)? Shouldn't the cls_token tensor be of shape [batch_size, 1, 192]?

I am not sure what the confusion is and may need clarification. However, I would say that when training the local aggregation (of 256-sized features to learn 4K-sized features) in HIPT, you can treat the number of [4096, 4096] regions essentially as a "minibatch", processing all [M x 256 x 384] features at once. The actual "batch size" (# of WSIs) for weakly-supervised learning is 1.
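A tiny sketch of that point (the tensors are random stand-ins; the dimensions follow vit4k_xs):

```python
import torch

M, embed_dim = 4, 192                        # M regions in one slide, vit4k_xs embed dim
cls_token = torch.zeros(1, 1, embed_dim)     # learnable parameter inside the region-level ViT
tokens = torch.randn(M, 256, embed_dim)      # 256 patch tokens per region after embedding
cls_tokens = cls_token.expand(M, -1, -1)     # B == M: one CLS token per region, not per slide
x = torch.cat((cls_tokens, tokens), dim=1)   # [M, 257, 192]; the slide-level batch size is 1
```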

  4. I've also tried training only the global aggregation layers by directly feeding the region-level pre-extracted features (of shape [M, 192]), without success (training loss not really decreasing either). Could you confirm that this should work just as well as training the intermediate transformer + the global aggregation layers on the [M, 256, 384] features?

I am sorry that you have not had success using the available region-level pre-extracted feature embeddings. What weakly-supervised scaffold code did you use? In this work, CLAM was used for weakly-supervised learning, which I slightly modified for HIPT. Here are the areas in the repository that may help you understand the loss curves and reproducibility.

What problems are you looking to apply HIPT to? I appreciated reading your detailed post on getting this method to work correctly, and I would be happy to understand and work through any pain points you have in using this method on TCGA (and other downstream tasks).

clemsgrs commented 2 years ago

Hi @Richarizardd, thank you for the quick and detailed answer!

  1. ok makes sense!
  2. Indeed, when using img_size = [256], all pre-trained weights load nicely except the positional embedding (because of the shape mismatch). I just realised that, in that case, my positional embedding stays a random tensor during the whole training process (given it gets initialised as such & then frozen). I'll stick to img_size = [224] for now.
  3. True, the confusion was mine! I thought this was happening in the last transformer block (aggregating region-level features into a single slide-level representation), but as you pointed out it's happening in the intermediate transformer block, where it makes total sense to have one cls_token per region.
  4. As a weakly-supervised scaffold, I'm using the same modified CLAM code as you did. Thanks for adding the training commands & the HIPT_GP_FC model (I had implemented the same model myself). I was using a different set of pre-extracted features than the ones you provide via git LFS (I slightly adapted the CLAM patch extraction pipeline), so I'll switch to using yours as a first step. Thank you for the list of resources I can use to debug what is happening on my side. I had forgotten about the Self-Supervised KNN; it seems a good place to start!

I'm looking to apply HIPT to various computational pathology problems and compare how other methods perform on the same tasks.

Once again, big thanks for the fast & detailed answer. Will reach back to you when I have something new!

clemsgrs commented 1 year ago

Finally got it working!

[Training curves: loss now decreases as expected]

The issue came from me using img_size = [256] instead of [224] when instantiating the VisionTransformer / VisionTransformer4K components of the HIPT model. This caused the pre-trained weights for the positional embedding parameter to be skipped when loading the state dict, because of the shape mismatch. As a result, when generating region-level features, the positional embeddings were left as initialised, i.e. random (normally distributed) tensors! This caused my region-level features to be garbage...

I've re-generated the region-level features with img_size = [224] and now get decent loss profiles & AUC numbers, great!

Before I close this issue, I had two small follow-up questions:

  1. I get a warning when training with [M, 256, 384] features that comes from interpolating the positional embeddings (lines 214 to 218): https://github.com/mahmoodlab/HIPT/blob/b5f4844f2d8b013d06807375166817eeb939a5aa/HIPT_4K/vision_transformer4k.py#L214-L218 Here's the associated warning: UserWarning: Default upsampling behavior when mode=bicubic is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.

To suppress it, I would add align_corners=False, but I wanted to make sure this is the behaviour you would also go for.
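For clarity, this is the kind of change I have in mind (a standalone version of the interpolation step, using an explicit target size rather than a scale factor; shapes are illustrative):

```python
import torch
import torch.nn as nn

# 14 x 14 pretraining grid of positional embeddings -> 16 x 16 token grid, dim = 192
patch_pos_embed = torch.randn(1, 192, 14, 14)
patch_pos_embed = nn.functional.interpolate(
    patch_pos_embed,
    size=(16, 16),            # target token grid
    mode="bicubic",
    align_corners=False,      # current PyTorch default, made explicit to silence the warning
)
print(patch_pos_embed.shape)  # torch.Size([1, 192, 16, 16])
```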

  2. Just to make sure I'm not missing something: could you confirm that the only reason one would want to train a model on [M, 256, 384] features (i.e. a HIPT model with a pre-trained self.local_vit component), instead of training the global aggregation transformer on [M, 192] features, is to fine-tune the pre-trained self.local_vit (by allowing gradients to flow through this component)?

In case the pre-trained self.local_vit component gets frozen, the former should yield the same results as the latter. But given that it has a longer forward pass (the features additionally have to go through self.local_vit), one should favour the latter (see the toy sketch below).
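Here is a toy sketch of what I mean (the modules are simple stand-ins, not the actual HIPT components; only the gradient flow matters):

```python
import torch
import torch.nn as nn

local_vit = nn.Linear(256 * 384, 192)        # stand-in for the pre-trained self.local_vit
global_layers = nn.Linear(192, 2)            # stand-in for the global aggregation layers

for p in local_vit.parameters():
    p.requires_grad = False                  # "frozen", i.e. freeze_4k = True

x = torch.randn(4, 256, 384)                 # M = 4 regions of [256, 384] features
region_feats = local_vit(x.flatten(1))       # [M, 192], recomputed on every forward pass
slide_logits = global_layers(region_feats.mean(dim=0, keepdim=True))
slide_logits.sum().backward()
print(local_vit.weight.grad)                 # None: no gradient reaches the frozen component,
                                             # so precomputing the [M, 192] features is equivalent
```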

Thank you for your previous answer, it really helped me find where the issue was coming from! Now that it is fixed, I'll try to reproduce the experiments you report in the paper & look at the ones you've recently run and linked in your answer above. Will be interesting!

Richarizardd commented 1 year ago
  1. Yes - align_corners=False
  2. Though the "longer forward pass" that uses self.local_vit is more expensive to run, one can do more data augmentation by running self.local_vit with dropout. For larger datasets, finetuning self.local_vit may also be helpful. Lastly, another advantage is that with features at both the 256- and 4096-level, one can also explore other variations, such as concatenating: 1) the slide feature from aggregating 256-level features via ABMIL, 2) the slide feature from aggregating 4096-level features via ABMIL, and 3) the slide feature from the last Transformer. I have not tried other strategies, but it seems intuitive for capturing the "different scales of features" across resolutions. It would also be fun to mix-and-match different aggregation functions.
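A rough sketch of the concatenation idea (the ABMIL / Transformer pooling is abstracted away; shapes and the classifier are illustrative):

```python
import torch
import torch.nn as nn

dim = 192
slide_feat_256 = torch.randn(1, dim)   # slide feature from ABMIL over 256-level features
slide_feat_4k  = torch.randn(1, dim)   # slide feature from ABMIL over 4096-level features
slide_feat_tfm = torch.randn(1, dim)   # slide feature from the last Transformer aggregator

fused = torch.cat([slide_feat_256, slide_feat_4k, slide_feat_tfm], dim=1)  # [1, 576]
classifier = nn.Linear(3 * dim, 2)     # subtyping head on the fused representation
logits = classifier(fused)
```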

Thank you for reporting these issues again. I will reflect these changes sometime this weekend.

clemsgrs commented 1 year ago

For the record, I got 0.883 ± 0.06 AUC for the breast subtyping task (ILC vs. IDC) using the same dataset (same 875 slides, same 10 folds).

This is on par with the results reported in the paper (0.874 ± 0.06, see Table 1).

The slight difference comes from me using different region-level pre-extracted features: I slightly adapted the CLAM patching code to generate [4096, 4096] regions per slide, then used the provided pre-trained weights to produce region-level features of shape [M, 192]. For each slide, I end up with slightly different regions.

bryanwong17 commented 1 year ago

Hi @clemsgrs, I have a small question. When training at the slide level, did you set freeze_4k = True?

clemsgrs commented 1 year ago

Hi @bryanwong17, when training on region-level features (i.e. a sequence of embeddings shaped [M, 192]), I did set freeze_4k = True.

bryanwong17 commented 1 year ago

Hi @clemsgrs, is the input for training HIPT_LGP_FC [M, 256, 384]? When we set pretrain_4k != None, it loads vit4k_xs_dino.pth and reduces the dimension to [M, 192]? Then we set freeze_4k=True?