lightly-ai / lightly

A python library for self-supervised learning on images.
https://docs.lightly.ai/self-supervised-learning/
MIT License
2.84k stars 246 forks source link

Get embedding for Transformer Backbone in MAE and PMSN #1505

Closed ramdhan1989 closed 4 months ago

ramdhan1989 commented 4 months ago

Hi, how can I get embeddings from a Transformer-based backbone? I have trained it using MAE and PMSN (following the tutorial), but I get this error during inference when I try to extract embeddings. For MAE, the error with input size torch.Size([1, 3, 224, 224]) is: `--------------------------------------------------------------------------- IndexError Traceback (most recent call last) Cell In[22], line 4 2 vectors_query = [] 3 for batch_num, (inputs, labels,filename) in enumerate(queryloader): ----> 4 out = backbone.encode(inputs).flatten().detach().numpy() 5 vectors_query.append(out) 6 filename_query.append(filename[0].split('\\')[-1])

File ~\anaconda3\envs\cuts\lib\site-packages\lightly\models\modules\masked_autoencoder.py:289, in MAEBackbone.encode(self, images, idx_keep) 272 """Returns encoded class and patch tokens from images. 273 274 Args: (...) 286 287 """ 288 out = self.images_to_tokens(images, prepend_class_token=True) --> 289 return self.encoder(out, idx_keep)

File ~\anaconda3\envs\cuts\lib\site-packages\torch\nn\modules\module.py:1148, in Module._call_impl(self, *input, *kwargs) 1145 bw_hook = hooks.BackwardHook(self, full_backward_hooks) 1146 input = bw_hook.setup_input_hook(input) -> 1148 result = forward_call(input, *kwargs) 1149 if _global_forward_hooks or self._forward_hooks: 1150 for hook in (_global_forward_hooks.values(), *self._forward_hooks.values()):

File ~\anaconda3\envs\cuts\lib\site-packages\lightly\models\modules\masked_autoencoder.py:104, in MAEEncoder.forward(self, input, idx_keep) 102 if idx_keep is not None: 103 input = utils.get_at_index(input, idx_keep) --> 104 return self.ln(self.layers(self.dropout(input)))

File ~\anaconda3\envs\cuts\lib\site-packages\torch\nn\modules\module.py:1130, in Module._call_impl(self, *input, *kwargs) 1126 # If we don't have any hooks, we want to skip the rest of the logic in 1127 # this function, and just call forward. 1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1129 or _global_forward_hooks or _global_forward_pre_hooks): -> 1130 return forward_call(input, **kwargs) 1131 # Do not call functions when jit is used 1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\cuts\lib\site-packages\torch\nn\modules\container.py:139, in Sequential.forward(self, input) 137 def forward(self, input): 138 for module in self: --> 139 input = module(input) 140 return input

File ~\anaconda3\envs\cuts\lib\site-packages\torch\nn\modules\module.py:1148, in Module._call_impl(self, *input, *kwargs) 1145 bw_hook = hooks.BackwardHook(self, full_backward_hooks) 1146 input = bw_hook.setup_input_hook(input) -> 1148 result = forward_call(input, *kwargs) 1149 if _global_forward_hooks or self._forward_hooks: 1150 for hook in (_global_forward_hooks.values(), *self._forward_hooks.values()):

File ~\AppData\Roaming\Python\Python310\site-packages\torchvision\models\vision_transformer.py:113, in EncoderBlock.forward(self, input) 111 torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") 112 x = self.ln1(input) --> 113 x, = self.self_attention(query=x, key=x, value=x, need_weights=False) 114 x = self.dropout(x) 115 x = x + input

File ~\anaconda3\envs\cuts\lib\site-packages\torch\nn\modules\module.py:1151, in Module._call_impl(self, *input, *kwargs) 1149 if _global_forward_hooks or self._forward_hooks: 1150 for hook in (_global_forward_hooks.values(), *self._forward_hooks.values()): -> 1151 hook_result = hook(self, input, result) 1152 if hook_result is not None: 1153 result = hook_result

File ~\anaconda3\envs\cuts\lib\site-packages\torchsummary\torchsummary.py:19, in summary..register_hook..hook(module, input, output) 17 m_key = "%s-%i" % (class_name, module_idx + 1) 18 summary[m_key] = OrderedDict() ---> 19 summary[m_key]["input_shape"] = list(input[0].size()) 20 summary[m_key]["input_shape"][0] = batch_size 21 if isinstance(output, (list, tuple)):

IndexError: tuple index out of range`

For PMSN, with input size torch.Size([1, 3, 224, 224]), the error is: `--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) Cell In[17], line 4 2 vectors_query = [] 3 for batch_num, (inputs, labels,filename) in enumerate(queryloader): ----> 4 out = backbone(inputs).flatten().detach().numpy() 5 vectors_query.append(out) 6 filename_query.append(filename[0].split('\\')[-1])

File ~\anaconda3\envs\py38\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, *kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\py38\lib\site-packages\lightly\models\modules\masked_autoencoder.py:265, in MAEBackbone.forward(self, images, idx_keep) 246 def forward( 247 self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None 248 ) -> torch.Tensor: 249 """Returns encoded class tokens from a batch of images. 250 251 Args: (...) 263 264 """ --> 265 out = self.encode(images, idx_keep) 266 class_token = out[:, 0] 267 return class_token

File ~\anaconda3\envs\py38\lib\site-packages\lightly\models\modules\masked_autoencoder.py:289, in MAEBackbone.encode(self, images, idx_keep) 272 """Returns encoded class and patch tokens from images. 273 274 Args: (...) 286 287 """ 288 out = self.images_to_tokens(images, prepend_class_token=True) --> 289 return self.encoder(out, idx_keep)

File ~\anaconda3\envs\py38\lib\site-packages\torch\nn\modules\module.py:1538, in Module._call_impl(self, *args, *kwargs) 1535 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks) 1536 args = bw_hook.setup_input_hook(args) -> 1538 result = forward_call(args, *kwargs) 1539 if _global_forward_hooks or self._forward_hooks: 1540 for hook_id, hook in ( 1541 _global_forward_hooks.items(), 1542 *self._forward_hooks.items(), 1543 ):

File ~\anaconda3\envs\py38\lib\site-packages\lightly\models\modules\masked_autoencoder.py:104, in MAEEncoder.forward(self, input, idx_keep) 102 if idx_keep is not None: 103 input = utils.get_at_index(input, idx_keep) --> 104 return self.ln(self.layers(self.dropout(input)))

File ~\anaconda3\envs\py38\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, *kwargs) 1496 # If we don't have any hooks, we want to skip the rest of the logic in 1497 # this function, and just call forward. 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1499 or _global_backward_pre_hooks or _global_backward_hooks 1500 or _global_forward_hooks or _global_forward_pre_hooks): -> 1501 return forward_call(args, **kwargs) 1502 # Do not call functions when jit is used 1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\py38\lib\site-packages\torch\nn\modules\container.py:217, in Sequential.forward(self, input) 215 def forward(self, input): 216 for module in self: --> 217 input = module(input) 218 return input

File ~\anaconda3\envs\py38\lib\site-packages\torch\nn\modules\module.py:1538, in Module._call_impl(self, *args, *kwargs) 1535 bw_hook = hooks.BackwardHook(self, full_backward_hooks, backward_pre_hooks) 1536 args = bw_hook.setup_input_hook(args) -> 1538 result = forward_call(args, *kwargs) 1539 if _global_forward_hooks or self._forward_hooks: 1540 for hook_id, hook in ( 1541 _global_forward_hooks.items(), 1542 *self._forward_hooks.items(), 1543 ):

File ~\anaconda3\envs\py38\lib\site-packages\torchvision\models\vision_transformer.py:113, in EncoderBlock.forward(self, input) 111 torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") 112 x = self.ln1(input) --> 113 x, = self.self_attention(x, x, x, need_weights=False) 114 x = self.dropout(x) 115 x = x + input

File ~\anaconda3\envs\py38\lib\site-packages\torch\nn\modules\module.py:1547, in Module._call_impl(self, *args, **kwargs) 1545 hook_result = hook(self, args, kwargs, result) 1546 else: -> 1547 hook_result = hook(self, args, result) 1549 if hook_result is not None: 1550 result = hook_result

File ~\anaconda3\envs\py38\lib\site-packages\torchsummary\torchsummary.py:22, in summary..register_hook..hook(module, input, output) 20 summary[m_key]["input_shape"][0] = batch_size 21 if isinstance(output, (list, tuple)): ---> 22 summary[m_key]["output_shape"] = [ 23 [-1] + list(o.size())[1:] for o in output 24 ] 25 else: 26 summary[m_key]["output_shape"] = list(output.size())

File ~\anaconda3\envs\py38\lib\site-packages\torchsummary\torchsummary.py:23, in (.0) 20 summary[m_key]["input_shape"][0] = batch_size 21 if isinstance(output, (list, tuple)): 22 summary[m_key]["output_shape"] = [ ---> 23 [-1] + list(o.size())[1:] for o in output 24 ] 25 else: 26 summary[m_key]["output_shape"] = list(output.size())

AttributeError: 'NoneType' object has no attribute 'size'`

Thank you

guarin commented 4 months ago

To get embeddings you can simply pass the images to the MAEBackbone model:

>>> vit = torchvision.models.vit_b_32()
>>> model = MAEBackbone.from_vit(vit)
>>> images = torch.rand(1, 3, 224, 224)
>>> model(images).shape
torch.Size([1, 768])

Could you verify that your images have the correct type and shape?

ramdhan1989 commented 4 months ago

Okay, it works. Thank you