owkin / HistoSSLscaling

Code associated to the publication: Scaling self-supervised learning for histopathology with masked image modeling, A. Filiot et al., MedRxiv (2023). We publicly release Phikon 🚀

Slicing of scores variable in Chowder model leading to shape mismatch in forward pass #18

Closed VolodymyrChapman closed 9 months ago

VolodymyrChapman commented 10 months ago

Issue description:

The input features are sliced when the scores tensor is computed on line 168 of chowder.py, which leads to a shape mismatch and a failing forward pass. The reason for the slicing is not clear. https://github.com/owkin/HistoSSLscaling/blob/e66c7ca4392449f56b54645acc1a0dc0f42a868b/rl_benchmarks/models/slide_models/chowder.py#L168

Example error with Phikon output of shape (1, 50, 768) (batch, patches, Phikon features) as input:

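import torch
# Chowder is the class defined in rl_benchmarks/models/slide_models/chowder.py
features = torch.randn(1, 50, 768)  # random placeholder for Phikon output: (batch, patches, features)
chowder_kwargs = {'in_features': 768, 'out_features': 3, 'n_top': 10, 'n_bottom': 10, 'mlp_hidden': [192, 96], 'mlp_dropout': [0.3, 0.3]}
model = Chowder(**chowder_kwargs)
model.eval()
model(features)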
Cell In[35], line 28
     25 model = Chowder(**chowder_kwargs)
     27 model.eval()
---> 28 model(features)

File ~/miniconda3/envs/rl_benchmarks/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[34], line 157
    141 def forward(
    142     self, features: torch.Tensor, mask: Optional[torch.BoolTensor] = None
    143 ) -> torch.Tensor:
    144     """
    145     Parameters
    146     ----------
   (...)
    155         (B, OUT_FEATURES), (B, N_TOP + N_BOTTOM, OUT_FEATURES)
    156     """
--> 157     scores = self.score_model(x=features[..., 3:], mask=mask)
    158     # scores = self.score_model(x=features, mask=mask)
    159     print('scores shape:', scores.shape)

File ~/miniconda3/envs/rl_benchmarks/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/GitHub/HistoSSLscaling/rl_benchmarks/models/slide_models/utils/tile_layers.py:148, in TilesMLP.forward(self, x, mask)
    146         x = layer(x, mask)
    147     else:
--> 148         x = layer(x)
    149 return x

File ~/miniconda3/envs/rl_benchmarks/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/rl_benchmarks/lib/python3.8/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (50x765 and 768x3)
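
The numbers in the error follow from the slice: features[..., 3:] drops three of the 768 feature columns, leaving 765, while the first linear layer still expects 768 inputs. A minimal reproduction sketch, assuming a bare torch.nn.Linear as a stand-in for the scoring layer (the real TilesMLP has more structure):

```python
import torch

features = torch.randn(1, 50, 768)   # (batch, patches, Phikon features)
layer = torch.nn.Linear(768, 3)      # stand-in for the first scoring layer (assumption)

sliced = features[..., 3:]           # drops the first 3 columns of the last axis
print(sliced.shape)                  # torch.Size([1, 50, 765])

try:
    layer(sliced)                    # 765 input columns vs. 768 expected
    error_message = None
except RuntimeError as err:
    error_message = str(err)

unsliced_out = layer(features)       # the unsliced input matches the layer
print(unsliced_out.shape)            # torch.Size([1, 50, 3])
```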

Suggested solution:

Changing line 168 to the below resolves the issue: scores = self.score_model(x=features, mask=mask)

Output when running the above example with modified line 168:

tensor([[-0.0729, -0.0908, -0.0894]], grad_fn=<SqueezeBackward1>)

jbschiratti commented 10 months ago

Use this instead.

scores = self.score_model(x=features, mask=mask)

The processed data released by Owkin are arrays with shape (n_tiles, 771). The first 3 columns of each features array are metadata (i.e. zoom level, 1st coordinate of tile address, 2nd coordinate of tile address). If your data does not have these 3 additional columns, then the slicing is not needed. In fact, it should be optional 🐛
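
To illustrate the layout described above, here is a small sketch that separates the 3 metadata columns from the 768 feature columns. The array is a zero-filled placeholder and the variable names are hypothetical, not taken from the repository:

```python
import numpy as np

n_tiles = 50
# Placeholder mimicking the released layout: 3 metadata columns + 768 features.
slide_array = np.zeros((n_tiles, 771), dtype=np.float32)

N_METADATA = 3  # zoom level, 1st tile coordinate, 2nd tile coordinate
metadata = slide_array[:, :N_METADATA]   # shape (n_tiles, 3)
features = slide_array[:, N_METADATA:]   # shape (n_tiles, 768)

print(metadata.shape, features.shape)
```

Features extracted directly with Phikon have only the 768 feature columns, so applying the same slice to them removes real features instead of metadata.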

VolodymyrChapman commented 10 months ago

Thanks for the quick response :-) Ahhh, that makes sense. Making slicing an optional argument within the Chowder class sounds like the best way to maintain essential behaviour for Owkin and ease of use, if using ibot_vit as a feature extractor :-) Please let me know if you would be open to a pull request - can do on Monday. Best wishes, V
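
One way the optional slicing could look, as a rough sketch only. ScoreHead and n_metadata_cols are hypothetical names for illustration and are not the actual Chowder signature or the change that was eventually merged:

```python
import torch


class ScoreHead(torch.nn.Module):
    """Hypothetical stand-in for a scoring model with optional metadata slicing."""

    def __init__(self, in_features: int, out_features: int, n_metadata_cols: int = 0):
        super().__init__()
        self.n_metadata_cols = n_metadata_cols
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Drop leading metadata columns only when the caller asked for it.
        if self.n_metadata_cols > 0:
            features = features[..., self.n_metadata_cols:]
        return self.linear(features)


# Raw Phikon features: 768 columns, no slicing.
plain = ScoreHead(in_features=768, out_features=3)
# Arrays with 3 leading metadata columns: 771 columns, slicing enabled.
with_metadata = ScoreHead(in_features=768, out_features=3, n_metadata_cols=3)

out_plain = plain(torch.randn(1, 50, 768))
out_with_metadata = with_metadata(torch.randn(1, 50, 771))
print(out_plain.shape, out_with_metadata.shape)
```

Defaulting to no slicing keeps the plain feature-extractor case working out of the box, while the 771-column arrays need the argument set explicitly. The opposite default would preserve the current behaviour instead.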

jbschiratti commented 10 months ago

Yes, please. PRs are very welcome

VolodymyrChapman commented 9 months ago

Great - PR submitted :-)

VolodymyrChapman commented 9 months ago

FYI for future users: Addressed in PR: https://github.com/owkin/HistoSSLscaling/pull/19

afilt commented 9 months ago

Resolved in #19.