Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Multiple Classification Head CNN #7706

Closed: abhisuri97 closed this issue 5 months ago

abhisuri97 commented 5 months ago

Hello, I am attempting to modify DenseNet121 to have multiple classification heads. However, when I backpropagate the losses, I get an error message saying that one of the variables needed for gradient computation has been modified by an inplace operation. The error and my initial thoughts on how to address it are at the bottom.

Here is my code:

Slightly modified DenseNet definition

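from collections import OrderedDict
from collections.abc import Sequence

import torch
import torch.nn as nn
from monai.networks.layers.factories import Conv, Pool
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.networks.nets.densenet import _DenseBlock, _Transition


class DenseNet(nn.Module):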
    """
    Densenet based on: `Densely Connected Convolutional Networks <https://arxiv.org/pdf/1608.06993.pdf>`_.
    Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.
    This network is non-deterministic when `spatial_dims` is 3 and CUDA is enabled. Please check the link below
    for more details:
    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of the input channel.
        out_channels: number of the output classes.
        init_features: number of filters in the first convolution layer.
        growth_rate: how many filters to add each layer (k in paper).
        block_config: how many layers in each pooling block.
        bn_size: multiplicative factor for number of bottleneck layers.
            (i.e. bn_size * k features in the bottleneck layer)
        act: activation type and arguments. Defaults to relu.
        norm: feature normalization type and arguments. Defaults to batch norm.
        dropout_prob: dropout rate after each dense layer.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        init_features: int = 64,
        growth_rate: int = 32,
        block_config: Sequence[int] = (6, 12, 24, 16),
        bn_size: int = 4,
        act: str | tuple = ("relu", {"inplace": True}),
        norm: str | tuple = "batch",
        dropout_prob: float = 0.0,
    ) -> None:
        super().__init__()

        conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
        pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
        avg_pool_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
            Pool.ADAPTIVEAVG, spatial_dims
        ]

        self.features = nn.Sequential(
            OrderedDict(
                [
                    ("conv0", conv_type(in_channels, init_features, kernel_size=7, stride=2, padding=3, bias=False)),
                    ("norm0", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=init_features)),
                    ("relu0", get_act_layer(name=act)),
                    ("pool0", pool_type(kernel_size=3, stride=2, padding=1)),
                ]
            )
        )

        in_channels = init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                spatial_dims=spatial_dims,
                layers=num_layers,
                in_channels=in_channels,
                bn_size=bn_size,
                growth_rate=growth_rate,
                dropout_prob=dropout_prob,
                act=act,
                norm=norm,
            )
            self.features.add_module(f"denseblock{i + 1}", block)
            in_channels += num_layers * growth_rate
            if i == len(block_config) - 1:
                self.features.add_module(
                    "norm5", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
                )
            else:
                _out_channels = in_channels // 2
                trans = _Transition(
                    spatial_dims, in_channels=in_channels, out_channels=_out_channels, act=act, norm=norm
                )
                self.features.add_module(f"transition{i + 1}", trans)
                in_channels = _out_channels

        # pooling and classification

        self.class_layers = nn.Sequential(
            OrderedDict(
                [
                    ("relu", get_act_layer(name=act)),
                    ("pool", avg_pool_type(1)),
                    ("flatten", nn.Flatten(1)),
                    ("out", nn.Linear(in_channels, out_channels)),
                ]
            )
        )
        print(self.class_layers)

        for m in self.modules():
            if isinstance(m, conv_type):
                nn.init.kaiming_normal_(torch.as_tensor(m.weight))
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                nn.init.constant_(torch.as_tensor(m.weight), 1)
                nn.init.constant_(torch.as_tensor(m.bias), 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(torch.as_tensor(m.bias), 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        xp = self.features(x)
        # x = self.class_layers(xp)
        return xp

Custom DenseNet121 with two classification heads (one that outputs a class and another that outputs contrast status)

class CustomDenseNet121(DenseNet):
    def __init__(self, spatial_dims=3, in_channels=2, out_channels=3, contrast_classes=2):
        # Initialize the standard DenseNet121 with the usual parameters
        super().__init__(spatial_dims=spatial_dims, in_channels=in_channels, 
                         out_channels=out_channels,
                        init_features=64, growth_rate=32, block_config=(6,12,24,16))
        conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
        pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
        avg_pool_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
            Pool.ADAPTIVEAVG, spatial_dims
        ]
        # Additional head for contrast status prediction
        self.contrast_head = nn.Sequential(
            OrderedDict(
                [
                    ("relu", get_act_layer(name=("relu", {"inplace": True}))),
                    ("pool", avg_pool_type(1)),
                    ("flatten", nn.Flatten(1)),
                    ("out", nn.Linear(1024, contrast_classes)),
                ]
            )
        )
        print(self.contrast_head)

    def forward(self, x):
        # Extract features from the DenseNet backbone
        features = super().forward(x)
        # Each head applies its own activation, adaptive pooling and flattening
        # before its linear layer (both heads read the same `features` tensor)
        out_class = self.class_layers(features)
        out_contrast = self.contrast_head(features)

        return out_class, out_contrast

Training loop

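import monai
from monai.data import decollate_batch
from torch.utils.tensorboard import SummaryWriter

val_interval = 1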
best_metric = -1
best_metric_epoch = -1
patience = 0
max_patience = 20
writer = SummaryWriter(f"model-output/temprun")
epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    model.cuda()
lossfn = torch.nn.CrossEntropyLoss()
lossfn2 = torch.nn.CrossEntropyLoss()
for epoch in range(epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, class_l, contrast_l = batch_data["img"].to(device), batch_data["label_y"], batch_data["contrast_present"]
        class_lt = class_l.type(torch.LongTensor)
        class_le = class_lt.to(device)

        contrast_lt = contrast_l.type(torch.LongTensor)
        contrast_le = contrast_lt.to(device)

        optimizer.zero_grad()
        out_class, out_contrast = model(inputs)
        # loss = multi_task_loss(out_class, class_l, out_contrast, contrast_l)
        losscl = lossfn(out_class, class_le)
        lossca = lossfn2(out_contrast, contrast_le)
        losscl.backward()
        lossca.backward()
        optimizer.step()
        print('got here')
        step_loss = losscl.item() + lossca.item()
        epoch_loss += step_loss
        epoch_len = len(train_ds) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {step_loss:.4f}")
        writer.add_scalar("train_loss", step_loss, epoch_len * epoch + step)
    epoch_loss /= step
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            y_pred = torch.tensor([], dtype=torch.float32, device=device)
            y_pred_contrast = torch.tensor([], dtype=torch.float32, device=device)
            y = torch.tensor([], dtype=torch.long, device=device)
            for val_data in val_loader:
                val_images, val_labels, val_contrast = val_data["img"].to(device), val_data["label_y"].to(device), val_data["contrast_present"].to(device)
                y_pred = torch.cat([y_pred, model(val_images)[0]], dim=0)
                y = torch.cat([y, val_labels], dim=0)

            acc_value = torch.eq(y_pred.argmax(dim=1), y)
            acc_metric = acc_value.sum().item() / len(acc_value)
            y_onehot = [post_label(i) for i in decollate_batch(y, detach=False)]
            y_pred_act = [post_pred(i) for i in decollate_batch(y_pred, detach=False)]
            auc_result = monai.metrics.compute_roc_auc(torch.stack(y_pred_act), torch.stack(y_onehot), 
                average="micro")

            if auc_result > best_metric:
                best_metric = auc_result
                best_metric_epoch = epoch + 1
                patience = 0
                torch.save(model.state_dict(), f"SIIM-output/{basedir}/best_metric_model_classification3d_dict.pth")
                print("saved new best metric model")
            else:
                patience += 1
                torch.save(model.state_dict(), f"SIIM-output/{basedir}/epoch-{epoch+1}_metric_model_classification3d_dict.pth")
                if patience > max_patience:
                    print("NO IMPROVEMENT")
                    break
            print(
                "current epoch: {} current accuracy: {:.4f} current AUC: {:.4f} best AUC: {:.4f} at epoch {}".format(
                    epoch + 1, acc_metric, auc_result, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
            writer.add_scalar("val_auc", auc_result, epoch + 1)

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()

The error I get when I attempt to run it:

----------
epoch 1/5
torch.Size([2, 1024, 3, 3, 3])
metatensor([[ 0.1698, -0.0128,  0.1776],
        [ 0.2385,  0.0806,  0.1863]], device='cuda:0',
       grad_fn=<AliasBackward0>) torch.Size([2, 3])
metatensor([[-0.1847, -0.0188],
        [-0.2308, -0.0320]], device='cuda:0', grad_fn=<AliasBackward0>) torch.Size([2, 2])
/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/autograd/__init__.py:200: UserWarning: Error detected in MeanBackward1. Traceback of forward call that caused the error:
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/traitlets/config/application.py", line 1043, in launch_instance
    app.start()
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 736, in start
    self.io_loop.start()
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
    self._run_once()
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
    handle._run()
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 516, in dispatch_queue
    await self.process_one()
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 505, in process_one
    await dispatch(*args)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 412, in dispatch_shell
    await result
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 740, in execute_request
    reply_content = await reply_content
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
    res = shell.run_cell(
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 546, in run_cell
    return super().run_cell(*args, **kwargs)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3009, in run_cell
    result = self._run_cell(
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3064, in _run_cell
    result = runner(coro)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
    coro.send(None)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_3909226/900390216.py", line 29, in <module>
    out_class, out_contrast = model(inputs)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/ipykernel_3909226/1147907885.py", line 32, in forward
    out_class = self.class_layers(features)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/modules/pooling.py", line 1235, in forward
    return F.adaptive_avg_pool3d(input, self.output_size)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/functional.py", line 1229, in adaptive_avg_pool3d
    return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/overrides.py", line 1551, in handle_torch_function
    result = torch_func_method(public_api, types, args, kwargs)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/monai/data/meta_tensor.py", line 276, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/_tensor.py", line 1295, in __torch_function__
    ret = func(*args, **kwargs)
  File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/functional.py", line 1231, in adaptive_avg_pool3d
    return torch._C._nn.adaptive_avg_pool3d(input, _output_size)
 (Triggered internally at /opt/conda/conda-bld/pytorch_1682343970094/work/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[351], line 33
     31 losscl = lossfn(out_class, class_le)
     32 lossca = lossfn2(out_contrast, contrast_le)
---> 33 losscl.backward()
     34 lossca.backward()
     35 optimizer.step()

File /gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/_tensor.py:478, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    431 r"""Computes the gradient of current tensor w.r.t. graph leaves.
    432 
    433 The graph is differentiated using the chain rule. If the tensor is
   (...)
    475         used to compute the attr::tensors.
    476 """
    477 if has_torch_function_unary(self):
--> 478     return handle_torch_function(
    479         Tensor.backward,
    480         (self,),
    481         self,
    482         gradient=gradient,
    483         retain_graph=retain_graph,
    484         create_graph=create_graph,
    485         inputs=inputs,
    486     )
    487 torch.autograd.backward(
    488     self, gradient, retain_graph, create_graph, inputs=inputs
    489 )

File /gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/overrides.py:1551, in handle_torch_function(public_api, relevant_args, *args, **kwargs)
   1545     warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
   1546                   "will be an error in future, please define it as a classmethod.",
   1547                   DeprecationWarning)
   1549 # Use `public_api` instead of `implementation` so __torch_function__
   1550 # implementations can do equality/identity comparisons.
-> 1551 result = torch_func_method(public_api, types, args, kwargs)
   1553 if result is not NotImplemented:
   1554     return result

File /gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/monai/data/meta_tensor.py:276, in MetaTensor.__torch_function__(cls, func, types, args, kwargs)
    274 if kwargs is None:
    275     kwargs = {}
--> 276 ret = super().__torch_function__(func, types, args, kwargs)
    277 # if `out` has been used as argument, metadata is not copied, nothing to do.
    278 # if "out" in kwargs:
    279 #     return ret
    280 if _not_requiring_metadata(ret):

File /gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/_tensor.py:1295, in Tensor.__torch_function__(cls, func, types, args, kwargs)
   1292     return NotImplemented
   1294 with _C.DisableTorchFunctionSubclass():
-> 1295     ret = func(*args, **kwargs)
   1296     if func in get_default_nowrap_functions():
   1297         return ret

File /gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    477 if has_torch_function_unary(self):
    478     return handle_torch_function(
    479         Tensor.backward,
    480         (self,),
   (...)
    485         inputs=inputs,
    486     )
--> 487 torch.autograd.backward(
    488     self, gradient, retain_graph, create_graph, inputs=inputs
    489 )

File /gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/autograd/__init__.py:200, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    195     retain_graph = create_graph
    197 # The reason we repeat same the comment below is that
    198 # some Python versions print out the first line of a multi-line function
    199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    201     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    202     allow_unreachable=True, accumulate_grad=True)

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 1024, 3, 3, 3]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

I know this is a lot of information, but I'm not sure where to go from here. I haven't found any operations that were in place. The only thing that stands out in the output is the mention of PyTorch's 3D adaptive average pooling function. Notably, if I convert this network back to a single output (i.e., remove the out_contrast head from the custom DenseNet definition), it trains fine. So something about the multiple-output structure of this network is breaking the gradient computation.
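A minimal reproduction of this failure mode, written in plain PyTorch so it runs without MONAI. The `make_head` and `two_head_backward` helpers, the toy backbone, and all sizes are illustrative assumptions, not the original model. Both heads begin with an in-place ReLU applied to the same shared feature tensor, as `class_layers` and `contrast_head` do through `inplace=True`. The first head's ReLU saves its output for the backward pass; the second head's in-place ReLU then overwrites that same storage and bumps its version counter, so autograd refuses to compute the first head's gradient. With a single head nothing overwrites the saved tensor, which is consistent with the one-output network training fine.

```python
import torch
import torch.nn as nn


def make_head(in_features: int, n_out: int, inplace: bool) -> nn.Sequential:
    # Hypothetical head mirroring the structure above: activation -> global pool -> flatten -> linear
    return nn.Sequential(
        nn.ReLU(inplace=inplace),
        nn.AdaptiveAvgPool3d(1),
        nn.Flatten(1),
        nn.Linear(in_features, n_out),
    )


def two_head_backward(inplace: bool) -> bool:
    """Return True if backward succeeds, False if autograd reports an in-place modification."""
    torch.manual_seed(0)
    # Toy stand-in for the DenseNet feature extractor (ends in a norm layer, like norm5)
    backbone = nn.Sequential(nn.Conv3d(2, 8, kernel_size=3, padding=1), nn.BatchNorm3d(8))
    head_a = make_head(8, 3, inplace)
    head_b = make_head(8, 2, inplace)

    x = torch.randn(2, 2, 6, 6, 6)
    features = backbone(x)      # shared tensor feeding both heads
    out_a = head_a(features)    # with inplace=True, this ReLU overwrites `features`
    out_b = head_b(features)    # ...and so does this one, invalidating head_a's saved output
    loss = out_a.sum() + out_b.sum()
    try:
        loss.backward()
        return True
    except RuntimeError as err:
        if "modified by an inplace operation" in str(err):
            return False
        raise


print(two_head_backward(inplace=True))   # expected: False
print(two_head_backward(inplace=False))  # expected: True
```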

Alternatively, if there is another way to implement two classification heads in a CNN, it would be great to know about that too.
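One common alternative, sketched in plain PyTorch under stated assumptions: compute the activation, global pooling and flattening once in a shared step, and make each head a plain `nn.Linear`. The `TwoHeadNet` class, its toy convolutional backbone and the sizes are hypothetical; with MONAI, the backbone could instead be the DenseNet `features` stack, whose final output has 1024 channels for DenseNet121. No tensor is modified in place after it is shared, and the two cross-entropy losses are summed so that a single `backward()` call propagates both gradients through the backbone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoHeadNet(nn.Module):
    """Hypothetical shared-backbone network with two classification heads."""

    def __init__(self, in_channels: int = 2, feat_channels: int = 16,
                 n_classes: int = 3, n_contrast: int = 2) -> None:
        super().__init__()
        # Toy stand-in backbone; a DenseNet feature extractor could be used here instead
        self.backbone = nn.Sequential(
            nn.Conv3d(in_channels, feat_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(feat_channels),
        )
        # Shared post-processing, applied exactly once and out of place
        self.shared = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(1),
        )
        self.class_head = nn.Linear(feat_channels, n_classes)
        self.contrast_head = nn.Linear(feat_channels, n_contrast)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        pooled = self.shared(self.backbone(x))
        return self.class_head(pooled), self.contrast_head(pooled)


torch.manual_seed(0)
model = TwoHeadNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

inputs = torch.randn(2, 2, 8, 8, 8)
class_labels = torch.tensor([0, 2])
contrast_labels = torch.tensor([1, 0])

optimizer.zero_grad()
out_class, out_contrast = model(inputs)
# Sum the task losses so one backward pass covers both heads and the shared backbone
loss = F.cross_entropy(out_class, class_labels) + F.cross_entropy(out_contrast, contrast_labels)
loss.backward()
optimizer.step()
print(out_class.shape, out_contrast.shape)  # torch.Size([2, 3]) torch.Size([2, 2])
```

Because the activation now runs only once on the backbone output, `inplace=True` would also be safe in `shared`; `inplace=False` just makes the absence of in-place writes explicit. Weighting the terms (for example `loss_cls + w * loss_contrast`) is a common variation if one task dominates, but the weight is a tuning choice, not something from this thread.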

KumoLiu commented 5 months ago

Hi @abhisuri97, did you try setting inplace=False in the act layer? Thanks.
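A related pitfall, sketched in plain PyTorch with hypothetical toy layers (`trunk`, `head_a`, `head_b` and the squared-output losses are illustrative only): the training loop above calls `losscl.backward()` and then `lossca.backward()` on two losses that share the backbone's graph. The first call frees the graph's saved intermediate buffers, so the second call would likely fail with a "backward through the graph a second time" error even after the in-place ReLUs are switched to `inplace=False`. Passing `retain_graph=True` to the first call avoids that, but summing the losses and calling `backward()` once is simpler and gives the same gradients up to floating-point rounding.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical toy model: a shared trunk feeding two linear heads
trunk = nn.Sequential(nn.Linear(4, 8), nn.ReLU(inplace=False))
head_a = nn.Linear(8, 3)
head_b = nn.Linear(8, 2)
modules = (trunk, head_a, head_b)
x = torch.randn(5, 4)


def losses() -> tuple[torch.Tensor, torch.Tensor]:
    feats = trunk(x)  # shared part of the graph
    return head_a(feats).pow(2).mean(), head_b(feats).pow(2).mean()


def zero_all() -> None:
    for m in modules:
        m.zero_grad()


def grads() -> list[torch.Tensor]:
    return [p.grad.detach().clone() for m in modules for p in m.parameters()]


# Two separate backward calls: the second fails because the shared graph was freed
loss_a, loss_b = losses()
loss_a.backward()
try:
    loss_b.backward()
    second_backward_ok = True
except RuntimeError:
    second_backward_ok = False
print(second_backward_ok)  # expected: False

# Option 1: keep the graph alive during the first call
zero_all()
loss_a, loss_b = losses()
loss_a.backward(retain_graph=True)
loss_b.backward()
g_retain = grads()

# Option 2 (simpler): sum the losses and call backward once
zero_all()
loss_a, loss_b = losses()
(loss_a + loss_b).backward()
g_sum = grads()

print(all(torch.allclose(a, b) for a, b in zip(g_retain, g_sum)))  # expected: True
```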

(Converted to a discussion, as this is not related to functionality in the core library.)