MrGiovanni / ContinualLearning

[MICCAI 2023] Continual Learning for Abdominal Multi-Organ and Tumor Segmentation
https://www.cs.jhu.edu/~alanlab/Pubs23/zhang2023continual.pdf
Other
57 stars, 8 forks

Questions about training on the BTCV dataset #10

Open sharonlee12 opened 1 year ago

sharonlee12 commented 1 year ago

Hello, sorry to bother you. I'm new to medical image segmentation and ran into the following problem while trying to train on the BTCV dataset with your code:

1. First, I used `label_transfer.py` to generate the post-label `.h5` files. Since my goal is to segment the BTCV dataset, I made the following modifications and left the other settings unchanged:

```python
ORGAN_LIST = '../dataset/dataset_list/btcv_train.txt'
NUM_WORKER = 8
NUM_CLASS = 13
TRANSFER_LIST = ['01']
TEMPLATE = {
    '01': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
}
```

2. Second, I tried to train on the BTCV dataset with `train.py` using the following settings:

```bash
CUDA_VISIBLE_DEVICES=5 python train.py --phase train \
  --data_root_path /data/ContinualLearning/dataset/ \
  --train_data_txt_path /data/ContinualLearning/dataset/dataset_list/mybtcv_train.txt \
  --val_data_txt_path /data/ContinualLearning/dataset/dataset_list/mybtcv_val.txt \
  --organ_list 1 2 3 4 5 6 7 8 9 10 11 12 13 \
  --max_epoch 1001 --warmup_epoch 15 --batch_size 1 --num_samples 1 --lr 1e-4 \
  --model swinunetr --trans_encoding word_embedding \
  --out_nonlinear softmax --out_channels 14 \
  --log_name /data/ContinualLearning/log1018/ \
  --pretrain /data/ContinualLearning/pretrained_weights/swin_unetr.base_5000ep_f48_lr2e-4_pretrained.pt
```

I also changed `num_classes` to 14 in the two loss functions used for training:

```python
import torch
import torch.nn as nn

# `loss` is the repository's loss module that provides BinaryDiceLoss
# (its import line is not shown in the original post).


class DiceLoss(nn.Module):
    def __init__(self, weight=None, ignore_index=None, num_classes=14, **kwargs):
        super(DiceLoss, self).__init__()
        self.kwargs = kwargs
        self.weight = weight
        self.ignore_index = ignore_index
        self.num_classes = num_classes
        self.dice = loss.BinaryDiceLoss(**self.kwargs)

    def forward(self, predict, target, organ_list):
        total_loss = []
        predict = torch.sigmoid(predict)  # independent per-channel probabilities
        B = predict.shape[0]

        # Organ k is read from channel k - 1 of both the prediction and the target.
        for b in range(B):
            for organ in organ_list:
                dice_loss = self.dice(predict[b, organ - 1], target[b, organ - 1])
                total_loss.append(dice_loss)

        total_loss = torch.stack(total_loss)
        return total_loss.sum() / total_loss.shape[0]


class Multi_BCELoss(nn.Module):
    def __init__(self, ignore_index=None, num_classes=14, **kwargs):
        super(Multi_BCELoss, self).__init__()
        self.kwargs = kwargs
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, predict, target, organ_list):
        assert predict.shape[2:] == target.shape[2:], 'predict & target shape do not match'
        total_loss = []
        B = predict.shape[0]

        for b in range(B):
            for organ in organ_list:
                ce_loss = self.criterion(predict[b, organ - 1], target[b, organ - 1])
                total_loss.append(ce_loss)

        total_loss = torch.stack(total_loss)
        return total_loss.sum() / total_loss.shape[0]
```

3. Those are all of my modifications. During training, however, the Dice loss does not decrease and stays around 0.97. I suspect there may be a problem with the class settings in my modifications. Is something wrong with my changes? Sorry to bother you, and I hope to get guidance from you or anyone else familiar with the code. Thank you!
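One way to narrow this down is to check the loss on synthetic data where the right answer is known. The sketch below is a self-contained stand-in for the per-organ loss above: `binary_dice_loss`, `per_organ_dice_loss`, and `check_target_channels` are hypothetical helpers (a plain soft Dice, not necessarily identical to the repository's `BinaryDiceLoss`), and they follow the same convention as the posted code, where organ `k` is read from channel `k - 1`. A Dice loss near 1 means the soft overlap between prediction and target is close to zero. Since the command passes `--out_channels 14` while the target presumably holds 13 organ channels, the last lines simulate, purely as an assumption, what a one-channel shift (channel 0 reserved for background) would do to the loss.

```python
import torch


def binary_dice_loss(pred: torch.Tensor, target: torch.Tensor, smooth: float = 1.0) -> torch.Tensor:
    """Soft Dice loss for one binary channel (hypothetical stand-in for BinaryDiceLoss)."""
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    dice = (2 * intersection + smooth) / (pred.sum() + target.sum() + smooth)
    return 1 - dice


def per_organ_dice_loss(logits: torch.Tensor, target: torch.Tensor, organ_list) -> torch.Tensor:
    """Mean Dice loss over batch items and organs; organ k uses channel k - 1."""
    probs = torch.sigmoid(logits)
    losses = [
        binary_dice_loss(probs[b, organ - 1], target[b, organ - 1])
        for b in range(probs.shape[0])
        for organ in organ_list
    ]
    return torch.stack(losses).mean()


def check_target_channels(target: torch.Tensor, organ_list) -> list:
    """Return the organs whose target channel is empty across the whole batch."""
    return [organ for organ in organ_list if target[:, organ - 1].sum() == 0]


if __name__ == "__main__":
    organ_list = list(range(1, 14))          # 13 BTCV organs
    target = torch.zeros(1, 13, 8, 8, 8)     # one binary channel per organ
    for organ in organ_list:
        target[0, organ - 1, organ % 8] = 1.0  # give every organ a slab of voxels

    perfect_logits = (target * 2 - 1) * 20   # +20 inside each mask, -20 outside
    wrong_logits = -perfect_logits           # confidently wrong everywhere

    print(per_organ_dice_loss(perfect_logits, target, organ_list).item())  # close to 0
    print(per_organ_dice_loss(wrong_logits, target, organ_list).item())    # close to 1
    print(check_target_channels(target, organ_list))                       # []

    # Hypothetical misalignment: 14 output channels with channel 0 as background,
    # so organ k sits in channel k while the loss reads channel k - 1.
    shifted_logits = torch.full((1, 14, 8, 8, 8), -20.0)
    shifted_logits[:, 1:] = perfect_logits
    print(per_organ_dice_loss(shifted_logits, target, organ_list).item())  # close to 1
```

If the real loss behaves like the first case on a batch where the model output is replaced by the target itself, the loss code is probably fine and the problem is more likely in the data or channel layout; `check_target_channels` on real batches can also reveal organs whose post-label channel is always empty.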

longhainguyen commented 10 months ago

Could you please provide guidance on generating the post-label `.h5` files with `label_transfer.py`?
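For context on what the post-label data has to contain: the loss code in the question compares `predict[b, organ - 1]` with `target[b, organ - 1]` through `nn.BCEWithLogitsLoss`, which expects a target of the same shape as the prediction channel, so the target presumably stores one binary mask per organ. The sketch below is a hypothetical helper, `label_to_binary_channels`, that builds such a stack from a BTCV-style integer label map with organ labels 1 to 13. It is not the repository's `label_transfer.py`, and it leaves out writing the `.h5` file, since the exact layout that script produces isn't shown in this thread.

```python
import numpy as np


def label_to_binary_channels(label: np.ndarray, organ_ids) -> np.ndarray:
    """Turn an integer label volume into one binary mask per organ.

    Hypothetical helper: channel i holds the mask for organ_ids[i], so with
    organ_ids = [1, ..., 13] organ k ends up in channel k - 1.
    """
    organ_ids = np.asarray(organ_ids)
    # Broadcast (C, 1, 1, 1) against (1, H, W, D) to get (C, H, W, D) booleans.
    masks = label[None, ...] == organ_ids.reshape(-1, *([1] * label.ndim))
    return masks.astype(np.uint8)


if __name__ == "__main__":
    label = np.zeros((4, 4, 4), dtype=np.int64)
    label[0] = 1          # 16 voxels of organ 1
    label[1, :2] = 6      # 8 voxels of organ 6
    channels = label_to_binary_channels(label, list(range(1, 14)))
    print(channels.shape)                                            # (13, 4, 4, 4)
    print(channels[0].sum(), channels[5].sum(), channels[12].sum())  # 16 8 0
```

Because each voxel carries a single integer label, at most one channel is set per voxel; background (label 0) simply has no channel in this layout.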