Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Allow an arbitrary mask to be used in the self attention #8235

Closed Lucas-rbnt closed 58 minutes ago

Lucas-rbnt commented 3 days ago

Description

The aim of this PR is to enable the use of an arbitrary mask in the self-attention module, which is very useful for handling missing data or masked modeling.

The official torch implementations allow an arbitrary mask, and MONAI already supports masking through the causal argument. This change simply generalizes that by accepting a mask directly in the forward pass.

In the SABlock and TransformerBlock, it is now possible to pass a boolean mask of shape (BS, Seq_length). Only the columns corresponding to masked tokens are set to -inf, not the rows, as row masking is rarely done in common implementations and masked tokens do not contribute to the gradient anyway. When causal attention is requested, passing a mask is not supported, to avoid overlapping masks.
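For illustration, here is a minimal sketch (not the PR's actual code; shapes and variable names are assumptions) of how such a key-padding mask can be applied so that only the masked columns of the attention scores receive -inf:

import torch

bs, num_heads, seq_len, head_dim = 2, 4, 6, 8
q = torch.randn(bs, num_heads, seq_len, head_dim)
k = torch.randn(bs, num_heads, seq_len, head_dim)

# True = valid token, False = masked token (e.g. a missing modality or masked patch)
attn_mask = torch.ones(bs, seq_len, dtype=torch.bool)
attn_mask[:, -2:] = False  # pretend the last two tokens are masked

scores = q @ k.transpose(-2, -1) / head_dim**0.5   # (BS, heads, Seq, Seq)
# Broadcast the mask to (BS, 1, 1, Seq): only the columns (keys) of masked tokens
# get -inf, so valid tokens never attend to them; the rows are left untouched.
scores = scores.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
weights = scores.softmax(dim=-1)                    # masked columns get zero weight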

I haven't implemented an additive mask on the attention matrix, which would allow values other than -inf in certain cases, as in torch's scaled_dot_product_attention: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

If you think it's relevant, it could be added.
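For reference, a small sketch of the two conventions torch's scaled_dot_product_attention accepts, a boolean mask versus an additive float mask (the tensor shapes here are illustrative assumptions):

import torch
import torch.nn.functional as F

q = torch.randn(2, 4, 6, 8)   # (BS, heads, Seq, head_dim)
k = torch.randn(2, 4, 6, 8)
v = torch.randn(2, 4, 6, 8)

# Boolean mask: True = attend, False = ignore (the convention this PR exposes).
bool_mask = torch.ones(2, 1, 6, 6, dtype=torch.bool)
bool_mask[..., -2:] = False
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)

# Additive float mask: values are added to the attention logits, so penalties
# other than -inf become possible.
float_mask = torch.zeros(2, 1, 6, 6)
float_mask[..., -2:] = -1e4   # strong but finite penalty instead of -inf
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)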

Types of changes

ericspod commented 3 days ago

I think this is fine with the minor proposed change.

KumoLiu commented 1 day ago

/build

KumoLiu commented 1 day ago

It seems there is a TorchScript conversion issue caused by this addition.

======================================================================
[2024-11-24T08:22:21.492Z] ERROR: test_script_0 (tests.test_selfattention.TestResBlock)
[2024-11-24T08:22:21.492Z] ----------------------------------------------------------------------
[2024-11-24T08:22:21.492Z] Traceback (most recent call last):
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/parameterized/parameterized.py", line 620, in standalone_func
[2024-11-24T08:22:21.492Z]     return func(*(a + p.args), **p.kwargs, **kw)
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/test_selfattention.py", line 215, in test_script
[2024-11-24T08:22:21.492Z]     test_script_save(net, test_data)
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/utils.py", line 745, in test_script_save
[2024-11-24T08:22:21.492Z]     convert_to_torchscript(
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/monai/networks/utils.py", line 796, in convert_to_torchscript
[2024-11-24T08:22:21.492Z]     script_module = torch.jit.script(model, **kwargs)
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_script.py", line 1324, in script
[2024-11-24T08:22:21.492Z]     return torch.jit._recursive.create_script_module(
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 559, in create_script_module
[2024-11-24T08:22:21.492Z]     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
[2024-11-24T08:22:21.492Z]     create_methods_and_properties_from_stubs(
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
[2024-11-24T08:22:21.492Z]     concrete_type._create_methods_and_properties(
[2024-11-24T08:22:21.492Z] RuntimeError: 
[2024-11-24T08:22:21.492Z] Expression of type | cannot be used in a type expression:
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/monai/networks/blocks/selfattention.py", line 157
[2024-11-24T08:22:21.492Z]     def forward(self, x, attn_mask: torch.Tensor | None = None):
[2024-11-24T08:22:21.492Z]                                     ~~~~~~~~~~~~~~~~~~~ <--- HERE
[2024-11-24T08:22:21.492Z]         """
[2024-11-24T08:22:21.492Z]         Args:
[2024-11-24T08:22:21.492Z] 
[2024-11-24T08:22:21.492Z] 
[2024-11-24T08:22:21.492Z] ======================================================================
[2024-11-24T08:22:21.492Z] ERROR: test_script_1 (tests.test_selfattention.TestResBlock)
[2024-11-24T08:22:21.492Z] ----------------------------------------------------------------------
[2024-11-24T08:22:21.492Z] Traceback (most recent call last):
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/parameterized/parameterized.py", line 620, in standalone_func
[2024-11-24T08:22:21.492Z]     return func(*(a + p.args), **p.kwargs, **kw)
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/test_selfattention.py", line 215, in test_script
[2024-11-24T08:22:21.492Z]     test_script_save(net, test_data)
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/utils.py", line 745, in test_script_save
[2024-11-24T08:22:21.492Z]     convert_to_torchscript(
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/monai/networks/utils.py", line 796, in convert_to_torchscript
[2024-11-24T08:22:21.492Z]     script_module = torch.jit.script(model, **kwargs)
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_script.py", line 1324, in script
[2024-11-24T08:22:21.492Z]     return torch.jit._recursive.create_script_module(
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 559, in create_script_module
[2024-11-24T08:22:21.492Z]     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
[2024-11-24T08:22:21.492Z]     create_methods_and_properties_from_stubs(
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
[2024-11-24T08:22:21.492Z]     concrete_type._create_methods_and_properties(
[2024-11-24T08:22:21.492Z] RuntimeError: 
[2024-11-24T08:22:21.492Z] Expression of type | cannot be used in a type expression:
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/monai/networks/blocks/selfattention.py", line 157
[2024-11-24T08:22:21.492Z]     def forward(self, x, attn_mask: torch.Tensor | None = None):
[2024-11-24T08:22:21.492Z]                                     ~~~~~~~~~~~~~~~~~~~ <--- HERE
[2024-11-24T08:22:21.492Z]         """
[2024-11-24T08:22:21.492Z]         Args:
[2024-11-24T08:22:21.492Z] 
[2024-11-24T08:22:21.492Z] 
[2024-11-24T08:22:21.492Z] ======================================================================
[2024-11-24T08:22:21.492Z] ERROR: test_script_2 (tests.test_selfattention.TestResBlock)
[2024-11-24T08:22:21.492Z] ----------------------------------------------------------------------
[2024-11-24T08:22:21.492Z] Traceback (most recent call last):
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/parameterized/parameterized.py", line 620, in standalone_func
[2024-11-24T08:22:21.492Z]     return func(*(a + p.args), **p.kwargs, **kw)
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/test_selfattention.py", line 215, in test_script
[2024-11-24T08:22:21.492Z]     test_script_save(net, test_data)
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/utils.py", line 745, in test_script_save
[2024-11-24T08:22:21.492Z]     convert_to_torchscript(
[2024-11-24T08:22:21.492Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/monai/networks/utils.py", line 796, in convert_to_torchscript
[2024-11-24T08:22:21.492Z]     script_module = torch.jit.script(model, **kwargs)
[2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_script.py", line 1324, in script
[2024-11-24T08:22:21.493Z]     return torch.jit._recursive.create_script_module(
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 559, in create_script_module
[2024-11-24T08:22:21.493Z]     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
[2024-11-24T08:22:21.493Z]     create_methods_and_properties_from_stubs(
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
[2024-11-24T08:22:21.493Z]     concrete_type._create_methods_and_properties(
[2024-11-24T08:22:21.493Z] RuntimeError: Can't redefine method: forward on class: __torch__.monai.networks.blocks.selfattention.___torch_mangle_531.SABlock (of Python compilation unit at: 0x5eb6b10)
[2024-11-24T08:22:21.493Z] 
[2024-11-24T08:22:21.493Z] ======================================================================
[2024-11-24T08:22:21.493Z] ERROR: test_script_3 (tests.test_selfattention.TestResBlock)
[2024-11-24T08:22:21.493Z] ----------------------------------------------------------------------
[2024-11-24T08:22:21.493Z] Traceback (most recent call last):
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/parameterized/parameterized.py", line 620, in standalone_func
[2024-11-24T08:22:21.493Z]     return func(*(a + p.args), **p.kwargs, **kw)
[2024-11-24T08:22:21.493Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/test_selfattention.py", line 215, in test_script
[2024-11-24T08:22:21.493Z]     test_script_save(net, test_data)
[2024-11-24T08:22:21.493Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/tests/utils.py", line 745, in test_script_save
[2024-11-24T08:22:21.493Z]     convert_to_torchscript(
[2024-11-24T08:22:21.493Z]   File "/home/jenkins/agent/workspace/MONAI-premerge/monai/monai/networks/utils.py", line 796, in convert_to_torchscript
[2024-11-24T08:22:21.493Z]     script_module = torch.jit.script(model, **kwargs)
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_script.py", line 1324, in script
[2024-11-24T08:22:21.493Z]     return torch.jit._recursive.create_script_module(
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 559, in create_script_module
[2024-11-24T08:22:21.493Z]     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
[2024-11-24T08:22:21.493Z]     create_methods_and_properties_from_stubs(
[2024-11-24T08:22:21.493Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
[2024-11-24T08:22:21.493Z]     concrete_type._create_methods_and_properties(
[2024-11-24T08:22:21.493Z] RuntimeError: Can't redefine method: forward on class: __torch__.monai.networks.blocks.selfattention.SABlock (of Python compilation unit at: 0x5eb6b10)
[2024-11-24T08:22:21.493Z] 
[2024-11-24T08:22:21.493Z] ----------------------------------------------------------------------
[2024-11-24T08:22:21.493Z] Ran 15769 tests in 1527.821s
[2024-11-24T08:22:21.493Z] 
[2024-11-24T08:22:21.493Z] FAILED (errors=4, skipped=1100)
Lucas-rbnt commented 19 hours ago

> It seems there is a TorchScript conversion issue caused by this addition.

It seems to be due to the type annotation using the | operator:

> [2024-11-24T08:22:21.492Z] RuntimeError: Expression of type | cannot be used in a type expression:

This typing syntax requires Python 3.10 or later, but Python 3.9 is being used in the test environment:

> [2024-11-24T08:22:21.492Z]   File "/usr/local/lib/python3.9/dist-packages/torch/jit/_recursive.py", line 636, in

I used this notation because I thought I'd already seen it elsewhere in MONAI. The problem can be solved either by bumping the Python version used in the tests, or by switching to the older typing syntax:

from typing import Optional
...
attn_mask: Optional[torch.Tensor] = None

I can change it, whichever you prefer!
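As a sanity check, here is a minimal sketch (a hypothetical toy module, not MONAI's SABlock) suggesting that torch.jit.script accepts the Optional[...] spelling on Python 3.9, where the | annotation fails:

from typing import Optional

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical toy module, only to test scriptability of the signature."""

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Zero out masked tokens; a stand-in for real attention masking.
        if attn_mask is not None:
            x = x * attn_mask.unsqueeze(-1).to(x.dtype)
        return x


scripted = torch.jit.script(TinyBlock())  # scripts cleanly with Optional[...]
out = scripted(torch.randn(2, 6, 8), torch.ones(2, 6, dtype=torch.bool))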

KumoLiu commented 17 hours ago

> I can change it, whichever you prefer!

Yes, could you help convert this to the older typing syntax, as TorchScript does not support the | operator? Thanks.

KumoLiu commented 13 hours ago

/build