Hello, I am attempting to modify DenseNet121 to have multiple classification heads. However, when I'm computing the loss function, I get an error message saying that one of the variables has been modified by an inplace operation. The error and my initial thoughts/attempts to address the issue are at the bottom.
Here is my code:
Minorly modified DenseNet definition
class DenseNet(nn.Module):
"""
Densenet based on: `Densely Connected Convolutional Networks <https://arxiv.org/pdf/1608.06993.pdf>`_.
Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.
This network is non-deterministic When `spatial_dims` is 3 and CUDA is enabled. Please check the link below
for more details:
https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
Args:
spatial_dims: number of spatial dimensions of the input image.
in_channels: number of the input channel.
out_channels: number of the output classes.
init_features: number of filters in the first convolution layer.
growth_rate: how many filters to add each layer (k in paper).
block_config: how many layers in each pooling block.
bn_size: multiplicative factor for number of bottle neck layers.
(i.e. bn_size * k features in the bottleneck layer)
act: activation type and arguments. Defaults to relu.
norm: feature normalization type and arguments. Defaults to batch norm.
dropout_prob: dropout rate after each dense layer.
"""
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 24, 16),
bn_size: int = 4,
act: str | tuple = ("relu", {"inplace": True}),
norm: str | tuple = "batch",
dropout_prob: float = 0.0,
) -> None:
super().__init__()
conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
avg_pool_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
Pool.ADAPTIVEAVG, spatial_dims
]
self.features = nn.Sequential(
OrderedDict(
[
("conv0", conv_type(in_channels, init_features, kernel_size=7, stride=2, padding=3, bias=False)),
("norm0", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=init_features)),
("relu0", get_act_layer(name=act)),
("pool0", pool_type(kernel_size=3, stride=2, padding=1)),
]
)
)
in_channels = init_features
for i, num_layers in enumerate(block_config):
block = _DenseBlock(
spatial_dims=spatial_dims,
layers=num_layers,
in_channels=in_channels,
bn_size=bn_size,
growth_rate=growth_rate,
dropout_prob=dropout_prob,
act=act,
norm=norm,
)
self.features.add_module(f"denseblock{i + 1}", block)
in_channels += num_layers * growth_rate
if i == len(block_config) - 1:
self.features.add_module(
"norm5", get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels)
)
else:
_out_channels = in_channels // 2
trans = _Transition(
spatial_dims, in_channels=in_channels, out_channels=_out_channels, act=act, norm=norm
)
self.features.add_module(f"transition{i + 1}", trans)
in_channels = _out_channels
# pooling and classification
self.class_layers = nn.Sequential(
OrderedDict(
[
("relu", get_act_layer(name=act)),
("pool", avg_pool_type(1)),
("flatten", nn.Flatten(1)),
("out", nn.Linear(in_channels, out_channels)),
]
)
)
print(self.class_layers)
for m in self.modules():
if isinstance(m, conv_type):
nn.init.kaiming_normal_(torch.as_tensor(m.weight))
elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
nn.init.constant_(torch.as_tensor(m.weight), 1)
nn.init.constant_(torch.as_tensor(m.bias), 0)
elif isinstance(m, nn.Linear):
nn.init.constant_(torch.as_tensor(m.bias), 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
xp = self.features(x)
# x = self.class_layers(xp)
return xp
Custom DenseNet121 with two classification heads (one that outputs a class and another that outputs contrast status)
class CustomDenseNet121(DenseNet):
def __init__(self, spatial_dims=3, in_channels=2, out_channels=3, contrast_classes=2):
# Initialize the standard DenseNet121 with the usual parameters
super().__init__(spatial_dims=spatial_dims, in_channels=in_channels,
out_channels=out_channels,
init_features=64, growth_rate=32, block_config=(6,12,24,16))
conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
avg_pool_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
Pool.ADAPTIVEAVG, spatial_dims
]
# Additional head for contrast status prediction
self.contrast_head = nn.Sequential(
OrderedDict(
[
("relu", get_act_layer(name=("relu", {"inplace": True}))),
("pool", avg_pool_type(1)),
("flatten", nn.Flatten(1)),
("out", nn.Linear(1024, 2)),
]
)
)
print(self.contrast_head)
def forward(self, x):
# Extract features from the DenseNet backbone
features = super().forward(x)
# Apply adaptive pooling to reshape output for linear layers
# Reshape pooled output to (batch_size, num_features)
# Pass pooled features through the classification and contrast heads
out_class = self.class_layers(features)
out_contrast = self.contrast_head(features)
return out_class, out_contrast
Training loop
val_interval = 1
best_metric = -1
best_metric_epoch = -1
patience = 0
max_patience = 20
writer = SummaryWriter(f"model-output/temprun")
epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
model.cuda()
lossfn = torch.nn.CrossEntropyLoss()
lossfn2 = torch.nn.CrossEntropyLoss()
for epoch in range(epochs):
print("-" * 10)
print(f"epoch {epoch + 1}/{epochs}")
model.train()
epoch_loss = 0
step = 0
for batch_data in train_loader:
step += 1
inputs, class_l, contrast_l = batch_data["img"].to(device), batch_data["label_y"], batch_data["contrast_present"]
class_lt = class_l.type(torch.LongTensor)
class_le = class_lt.to(device)
contrast_lt = contrast_l.type(torch.LongTensor)
contrast_le = contrast_lt.to(device)
optimizer.zero_grad()
out_class, out_contrast = model(inputs)
# loss = multi_task_loss(out_class, class_l, out_contrast, contrast_l)
losscl = lossfn(out_class, class_le)
lossca = lossfn2(out_contrast, contrast_le)
losscl.backward()
lossca.backward()
optimizer.step()
print('got here')
# epoch_loss += loss.item()
epoch_len = len(train_ds) // train_loader.batch_size
print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
epoch_loss /= step
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
if (epoch + 1) % val_interval == 0:
model.eval()
with torch.no_grad():
y_pred = torch.tensor([], dtype=torch.float32, device=device)
y_pred_contrast = torch.tensor([], dtype=torch.float32, device=device)
y = torch.tensor([], dtype=torch.long, device=device)
for val_data in val_loader:
val_images, val_labels, val_contrast = val_data["img"].to(device), val_data["label_y"].to(device), val_data["contrast_present"].to(device)
y_pred = torch.cat([y_pred, model(val_images)[0]], dim=0)
y = torch.cat([y, val_labels], dim=0)
acc_value = torch.eq(y_pred.argmax(dim=1), y)
acc_metric = acc_value.sum().item() / len(acc_value)
y_onehot = [post_label(i) for i in decollate_batch(y, detach=False)]
y_pred_act = [post_pred(i) for i in decollate_batch(y_pred, detach=False)]
auc_result = monai.metrics.compute_roc_auc(torch.stack(y_pred_act), torch.stack(y_onehot),
average="micro")
if auc_result > best_metric:
best_metric = auc_result
best_metric_epoch = epoch + 1
patience = 0
torch.save(model.state_dict(), f"SIIM-output/{basedir}/best_metric_model_classification3d_dict.pth")
print("saved new best metric model")
else:
patience += 1
torch.save(model.state_dict(), f"SIIM-output/{basedir}/epoch-{epoch+1}_metric_model_classification3d_dict.pth")
if patience > max_patience:
print("NO IMPROVEMENT")
break
print(
"current epoch: {} current accuracy: {:.4f} current AUC: {:.4f} best AUC: {:.4f} at epoch {}".format(
epoch + 1, acc_metric, auc_result, best_metric, best_metric_epoch
)
)
writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
writer.add_scalar("val_auc", auc_result, epoch + 1)
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()
The error I get when I attempt to run it:
----------
epoch 1/5
torch.Size([2, 1024, 3, 3, 3])
metatensor([[ 0.1698, -0.0128, 0.1776],
[ 0.2385, 0.0806, 0.1863]], device='cuda:0',
grad_fn=<AliasBackward0>) torch.Size([2, 3])
metatensor([[-0.1847, -0.0188],
[-0.2308, -0.0320]], device='cuda:0', grad_fn=<AliasBackward0>) torch.Size([2, 2])
/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/autograd/__init__.py:200: UserWarning: Error detected in MeanBackward1. Traceback of forward call that caused the error:
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel_launcher.py", line 17, in <module>
app.launch_new_instance()
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/traitlets/config/application.py", line 1043, in launch_instance
app.start()
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 736, in start
self.io_loop.start()
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 195, in start
self.asyncio_loop.run_forever()
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
self._run_once()
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
handle._run()
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/asyncio/events.py", line 80, in _run
self._context.run(self._callback, *self._args)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 516, in dispatch_queue
await self.process_one()
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 505, in process_one
await dispatch(*args)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 412, in dispatch_shell
await result
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 740, in execute_request
reply_content = await reply_content
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 422, in do_execute
res = shell.run_cell(
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 546, in run_cell
return super().run_cell(*args, **kwargs)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3009, in run_cell
result = self._run_cell(
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3064, in _run_cell
result = runner(coro)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
coro.send(None)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3269, in run_cell_async
has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
if await self.run_code(code, result, async_=asy):
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/tmp/ipykernel_3909226/900390216.py", line 29, in <module>
out_class, out_contrast = model(inputs)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/tmp/ipykernel_3909226/1147907885.py", line 32, in forward
out_class = self.class_layers(features)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/modules/container.py", line 217, in forward
input = module(input)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/modules/pooling.py", line 1235, in forward
return F.adaptive_avg_pool3d(input, self.output_size)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/functional.py", line 1229, in adaptive_avg_pool3d
return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/overrides.py", line 1551, in handle_torch_function
result = torch_func_method(public_api, types, args, kwargs)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/monai/data/meta_tensor.py", line 276, in __torch_function__
ret = super().__torch_function__(func, types, args, kwargs)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/_tensor.py", line 1295, in __torch_function__
ret = func(*args, **kwargs)
File "/gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/nn/functional.py", line 1231, in adaptive_avg_pool3d
return torch._C._nn.adaptive_avg_pool3d(input, _output_size)
(Triggered internally at /opt/conda/conda-bld/pytorch_1682343970094/work/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[351], line 33
31 losscl = lossfn(out_class, class_le)
32 lossca = lossfn2(out_contrast, contrast_le)
---> 33 losscl.backward()
34 lossca.backward()
35 optimizer.step()
File /gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/_tensor.py:478, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
431 r"""Computes the gradient of current tensor w.r.t. graph leaves.
432
433 The graph is differentiated using the chain rule. If the tensor is
(...)
475 used to compute the attr::tensors.
476 """
477 if has_torch_function_unary(self):
--> 478 return handle_torch_function(
479 Tensor.backward,
480 (self,),
481 self,
482 gradient=gradient,
483 retain_graph=retain_graph,
484 create_graph=create_graph,
485 inputs=inputs,
486 )
487 torch.autograd.backward(
488 self, gradient, retain_graph, create_graph, inputs=inputs
489 )
File /gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/overrides.py:1551, in handle_torch_function(public_api, relevant_args, *args, **kwargs)
1545 warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
1546 "will be an error in future, please define it as a classmethod.",
1547 DeprecationWarning)
1549 # Use `public_api` instead of `implementation` so __torch_function__
1550 # implementations can do equality/identity comparisons.
-> 1551 result = torch_func_method(public_api, types, args, kwargs)
1553 if result is not NotImplemented:
1554 return result
File /gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/monai/data/meta_tensor.py:276, in MetaTensor.__torch_function__(cls, func, types, args, kwargs)
274 if kwargs is None:
275 kwargs = {}
--> 276 ret = super().__torch_function__(func, types, args, kwargs)
277 # if `out` has been used as argument, metadata is not copied, nothing to do.
278 # if "out" in kwargs:
279 # return ret
280 if _not_requiring_metadata(ret):
File /gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/_tensor.py:1295, in Tensor.__torch_function__(cls, func, types, args, kwargs)
1292 return NotImplemented
1294 with _C.DisableTorchFunctionSubclass():
-> 1295 ret = func(*args, **kwargs)
1296 if func in get_default_nowrap_functions():
1297 return ret
File /gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/_tensor.py:487, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
477 if has_torch_function_unary(self):
478 return handle_torch_function(
479 Tensor.backward,
480 (self,),
(...)
485 inputs=inputs,
486 )
--> 487 torch.autograd.backward(
488 self, gradient, retain_graph, create_graph, inputs=inputs
489 )
File /gpfs/gsfs12/users/suria2/conda/envs/abhi/lib/python3.11/site-packages/torch/autograd/__init__.py:200, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
195 retain_graph = create_graph
197 # The reason we repeat same the comment below is that
198 # some Python versions print out the first line of a multi-line function
199 # calls in the traceback and some print out the last line
--> 200 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
201 tensors, grad_tensors_, retain_graph, create_graph, inputs,
202 allow_unreachable=True, accumulate_grad=True)
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2, 1024, 3, 3, 3]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
I know this is a huge dump of information. But I'm a bit clueless as to where to go from here. I haven't found any operations that were in place. The only thing that somewhat sticks out from the output is that there is a mention of the adaptive pool 3d function from torch. Notably, if I do convert this network back to having just one output (i.e. get rid of the out_contrast head in the definition of the custom densenet), it works absolutely fine. So there's definitely something about the multiple output nature of this network that is messing up the gradient computation.
Alternatively, if there is a way to implement having two classification heads in a CNN, it'd also be great to know about that.
Hello, I am attempting to modify DenseNet121 to have multiple classification heads. However, when I'm computing the loss function, I get an error message saying that one of the variables has been modified by an inplace operation. The error and my initial thoughts/attempts to address the issue are at the bottom.
Here is my code:
Minorly modified DenseNet definition
Custom DenseNet121 with two classification heads (one that outputs a class and another that outputs contrast status)
Training loop
The error I get when I attempt to run it:
I know this is a huge dump of information. But I'm a bit clueless as to where to go from here. I haven't found any operations that were in place. The only thing that somewhat sticks out from the output is that there is a mention of the adaptive pool 3d function from torch. Notably, if I do convert this network back to having just one output (i.e. get rid of the out_contrast head in the definition of the custom densenet), it works absolutely fine. So there's definitely something about the multiple output nature of this network that is messing up the gradient computation.
Alternatively, if there is a way to implement having two classification heads in a CNN, it'd also be great to know about that.