YzzLiu / FracSegNet

A deep-learning method for segmenting fractures from 3D CT images.
https://link.springer.com/chapter/10.1007/978-3-031-43996-4_30
MIT License

Bug in function MultipleOutputLoss2 #6

Open ohcccc opened 2 months ago

ohcccc commented 2 months ago

Hi, thanks for your great work. I'm having some issues running your code. Could you give me some help? @YzzLiu

In function `get_tp_fp_fn_tn()`:

```python
def get_tp_fp_fn_tn(net_output, gt, disMap=None, axes=None, mask=None, square=False, current_epoch=None):
    smooth_trans = True
    Tauo_st = 0
    st_epoch = 1000

    if smooth_trans == False:
        if disMap != None:
            disMap = disMap / torch.mean(disMap)
            temp_disMap_value = disMap
    if smooth_trans == True:
        if disMap != None:
            disMap = disMap / torch.mean(disMap)
            if current_epoch < Tauo_st:
                temp_disMap_value = torch.ones_like(disMap)
            elif Tauo_st <= current_epoch < Tauo_st + st_epoch:
                warm_start_matrix = torch.ones_like(disMap)
                warm_para = float(Tauo_st + st_epoch - current_epoch) / st_epoch
                temp_disMap_value = warm_para * warm_start_matrix + (1 - warm_para) * disMap
            elif current_epoch >= Tauo_st + st_epoch:
                temp_disMap_value = disMap

    disMap = temp_disMap_value

    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape
    shp_disMap = disMap.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))
            disMap = disMap.view((shp_disMap[0], 1, *shp_disMap[1:]))
        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)
    disMap2onehot = disMap.repeat_interleave(shp_x[1], dim=1)

    disMap2onehot = disMap2onehot.cuda(net_output.device.index)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    tn = (1 - net_output) * (1 - y_onehot)

    tp = torch.mul(tp, disMap2onehot)
    fp = torch.mul(fp, disMap2onehot)
    fn = torch.mul(fn, disMap2onehot)
    tn = torch.mul(tn, disMap2onehot)

    if mask is not None:
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
        tn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tn, dim=1)), dim=1)

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2
        tn = tn ** 2

    if len(axes) > 0:
        tp = sum_tensor(tp, axes, keepdim=False)
        fp = sum_tensor(fp, axes, keepdim=False)
        fn = sum_tensor(fn, axes, keepdim=False)
        tn = sum_tensor(tn, axes, keepdim=False)

    return tp, fp, fn, tn
```

The line `tp = torch.mul(tp, disMap2onehot)` fails with: `The size of tensor a (128) must match the size of tensor b (218) at non-singleton dimension 4`. I added `disMap2onehot = torch.nn.functional.interpolate(disMap2onehot, size=tp.shape[2:])`, but I don't know if that is the right thing to do.
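The shape error and the interpolation workaround can be reproduced on small tensors. This is a sketch with made-up sizes (6³ and 10³ standing in for the 128³ and 218³ reported in the issue); `tp` and `dis_map_onehot` are illustrative stand-ins filled with random values. `F.interpolate` with a target `size` resamples the spatial dimensions of a 5-D tensor, so the multiplication succeeds afterwards. Note that matching shapes only keeps the weights aligned with the voxels of `tp` if both tensors cover the same region of the image; if the 218³ map covers a larger region than the 128³ patch, resizing stretches it onto the patch instead of cropping it.

```python
import torch
import torch.nn.functional as F

# Small stand-ins for the reported tensors (spatial sizes 128 vs. 218 in the issue).
tp = torch.rand(2, 4, 6, 6, 6)
dis_map_onehot = torch.rand(2, 4, 10, 10, 10)

mismatch_raised = False
try:
    torch.mul(tp, dis_map_onehot)
except RuntimeError as err:
    # Same kind of error as in the issue: the spatial sizes cannot broadcast.
    mismatch_raised = True
    print(err)

# The workaround from the issue: resample the weight map to tp's spatial size.
# For a 5-D input, F.interpolate resamples the last three (D, H, W) dimensions;
# the default mode is "nearest".
resized = F.interpolate(dis_map_onehot, size=tp.shape[2:])
weighted_tp = torch.mul(tp, resized)
print(tuple(resized.shape), tuple(weighted_tp.shape))
```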

ohcccc commented 2 months ago

In addition, there is a bug in function `RobustCrossEntropyLoss()`:

```
Exception has occurred: UnboundLocalError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
cannot access local variable 'temp_after_DisMap' where it is not associated with a value
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/nnunet/training/loss_functions/crossentropy.py", line 53, in forward
    return temp_after_DisMap
           ^^^^^^^^^^^^^^^^^
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/nnunet/training/loss_functions/dice_loss.py", line 378, in forward
    ce_loss = self.ce(net_output, target[:, 0].long(),disMap_weight[:,0],epoch)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/nnunet/training/loss_functions/deep_supervision.py", line 29, in forward
    l = weights[0] * self.loss(x[0], y[0],disMap[0],epoch)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/nnunet/training/network_training/nnUNetTrainerV2.py", line 252, in run_iteration
    l = self.loss(output, target, disMap, self.epoch)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/nnunet/training/network_training/network_trainer.py", line 462, in run_training
    l = self.run_iteration(self.tr_gen, True)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/nnunet/training/network_training/nnUNetTrainer.py", line 321, in run_training
    super(nnUNetTrainer, self).run_training()
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/site-packages/nnunet/training/network_training/nnUNetTrainerV2.py", line 446, in run_training
    ret = super().run_training()
          ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/icml/code/FracSegNet/code/Training/fracSegNet/run/run_training.py", line 179, in main
    trainer.run_training()
  File "/home/icml/code/FracSegNet/code/Training/fracSegNet/run/run_training.py", line 199, in <module>
    main()
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/runpy.py", line 88, in _run_code
    exec(code, run_globals)
  File "/home/icml/miniconda3/envs/nnunet/lib/python3.12/runpy.py", line 198, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
UnboundLocalError: cannot access local variable 'temp_after_DisMap' where it is not associated with a value
```
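`UnboundLocalError` is raised when a function reads a local variable that was only assigned on some of its branches. The code of `crossentropy.py` is not shown in this thread, but the `get_tp_fp_fn_tn()` quoted in the first comment has the same structure: `temp_disMap_value` is assigned only inside `if disMap != None:`, so calling it without a distance map would fail the same way. The sketch below is a hypothetical scalar stand-in for that smooth-transition logic (`blended_weight_buggy` and `blended_weight_fixed` are illustrative names, not FracSegNet functions, and the division by the mean of the map is omitted). The fixed version assigns a result on every path, assuming a uniform weight of 1.0 when no distance value is given.

```python
def blended_weight_buggy(dis_value, current_epoch, tau_st=0, st_epoch=1000):
    # Mirrors the branch structure of the quoted code: `weight` is only
    # assigned when a distance value is given.
    if dis_value is not None:
        if current_epoch < tau_st:
            weight = 1.0
        elif tau_st <= current_epoch < tau_st + st_epoch:
            warm_para = float(tau_st + st_epoch - current_epoch) / st_epoch
            weight = warm_para * 1.0 + (1 - warm_para) * dis_value
        else:
            weight = dis_value
    return weight  # UnboundLocalError when dis_value is None


def blended_weight_fixed(dis_value, current_epoch, tau_st=0, st_epoch=1000):
    # Assumption: without a distance map, fall back to a uniform weight of 1.0.
    if dis_value is None:
        return 1.0
    if current_epoch < tau_st:
        return 1.0
    if current_epoch < tau_st + st_epoch:
        # warm_para falls linearly from 1 to 0 over st_epoch epochs, moving
        # the weight from 1.0 towards the distance value.
        warm_para = float(tau_st + st_epoch - current_epoch) / st_epoch
        return warm_para * 1.0 + (1 - warm_para) * dis_value
    return dis_value


buggy_failed = False
try:
    blended_weight_buggy(None, current_epoch=5)
except UnboundLocalError as err:
    buggy_failed = True
    print("buggy version:", err)

print(blended_weight_fixed(None, 5))   # 1.0
print(blended_weight_fixed(3.0, 500))  # 2.0
```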

YzzLiu commented 2 months ago

disMap2onehot

Hello, thank you for running the FracSegNet code. I didn't encounter the issue you mentioned while running the code. Could you please provide the shape or size of the variable disMap2onehot so that I can help identify the problem?
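To collect that information, a quick check can be placed right before the first `torch.mul` in `get_tp_fp_fn_tn()`. This is a hypothetical debugging helper (`check_same_shape` is not part of FracSegNet or nnU-Net), and it assumes that after `repeat_interleave` the map `disMap2onehot` is meant to have exactly the same shape as `tp`.

```python
import torch


def check_same_shape(name_a, a, name_b, b):
    """Hypothetical debug helper: fail with both shapes if they differ."""
    if a.shape != b.shape:
        raise ValueError(
            f"{name_a} has shape {tuple(a.shape)} but {name_b} has shape {tuple(b.shape)}"
        )


# Example with small tensors whose spatial sizes differ.
tp = torch.zeros(1, 2, 4, 4, 4)
dis_map_onehot = torch.zeros(1, 2, 5, 5, 5)
try:
    check_same_shape("tp", tp, "disMap2onehot", dis_map_onehot)
except ValueError as err:
    message = str(err)
    print(message)
```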

YzzLiu commented 2 months ago


Thank you for bringing up this issue. When modifying the RobustCrossEntropyLoss function, I changed the `reduction` parameter of `nn.CrossEntropyLoss` to `'none'` so that the loss is computed per voxel. You can find and modify this parameter in `torch/nn/modules/loss.py`. This is probably the cause of your error, and I apologize for any inconvenience.
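Since `reduction` is a constructor argument of `nn.CrossEntropyLoss`, the same per-voxel behaviour can probably be obtained without editing the installed PyTorch source (an edit there affects every project using that environment). The sketch below is an assumption about how such a loss could look, not FracSegNet's actual `RobustCrossEntropyLoss`: `DistanceWeightedCE` is a hypothetical name, and taking the `.mean()` after weighting is one possible choice.

```python
import torch
import torch.nn as nn


class DistanceWeightedCE(nn.CrossEntropyLoss):
    """Hypothetical sketch: cross-entropy weighted voxel-wise by a distance map."""

    def __init__(self):
        # reduction="none" makes the parent return one loss value per voxel
        # instead of averaging, without touching torch/nn/modules/loss.py.
        super().__init__(reduction="none")

    def forward(self, logits, target, dis_map):
        # logits: (N, C, D, H, W); target and dis_map: (N, D, H, W)
        per_voxel = super().forward(logits, target.long())
        if per_voxel.shape != dis_map.shape:
            raise ValueError(
                f"loss shape {tuple(per_voxel.shape)} != distance map shape {tuple(dis_map.shape)}"
            )
        return (per_voxel * dis_map).mean()


loss_fn = DistanceWeightedCE()
logits = torch.zeros(1, 2, 3, 3, 3)  # equal logits: per-voxel loss is log(2)
target = torch.zeros(1, 3, 3, 3)
loss = loss_fn(logits, target, torch.full((1, 3, 3, 3), 2.0))
print(loss.item())  # about 1.386 (= 2 * log 2)
```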

ohcccc commented 1 month ago


Thank you for your reply. The shape of disMap2onehot is (8, 4, 218, 218, 218) and the shape of tp is (8, 4, 128, 128, 128). I don't know where the mistake lies.
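One possible reading of these shapes, which is an assumption not confirmed in this thread: nnU-Net (v1) typically loads patches larger than the final patch size so that spatial augmentation has room to rotate and scale, then crops them down to the patch size. A distance map that is loaded with the data but skips that spatial transform would keep the larger size, and resizing it would stretch the whole enlarged region onto the 128³ patch rather than select the matching voxels. The hypothetical `center_crop_3d` below illustrates only the cropping part; it keeps the map aligned only if no rotation, scaling, or elastic deformation was applied. Otherwise the map would need to go through the same spatial transform as the segmentation.

```python
import torch


def center_crop_3d(volume: torch.Tensor, size) -> torch.Tensor:
    """Hypothetical helper: crop the last three dims of `volume` to `size` around the center."""
    full = volume.shape[-3:]
    starts = [(f - s) // 2 for f, s in zip(full, size)]
    if any(st < 0 for st in starts):
        raise ValueError(f"cannot crop {tuple(full)} to larger size {tuple(size)}")
    (d0, h0, w0), (d, h, w) = starts, size
    return volume[..., d0:d0 + d, h0:h0 + h, w0:w0 + w]


# Toy volume whose value encodes its own (z, y, x) index as z*100 + y*10 + x.
dis_map = torch.arange(1000).reshape(1, 1, 10, 10, 10)
cropped = center_crop_3d(dis_map, (6, 6, 6))
print(tuple(cropped.shape), cropped[0, 0, 0, 0, 0].item())  # (1, 1, 6, 6, 6) 222
```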

ohcccc commented 3 weeks ago

@YzzLiu