ljwztc / CLIP-Driven-Universal-Model

[ICCV 2023] CLIP-Driven Universal Model; Rank first in MSD Competition.

Diceloss does not decrease during training, and Dice is all 0 during validation #56

Closed sharonlee12 closed 1 month ago

sharonlee12 commented 6 months ago

Hello! I am training on BTCV for 1000 epochs. The Dice loss oscillates continuously without decreasing, while the CE loss decreases. When I validate with my checkpoint, the result is as follows:

Spleen: dice 0.0000, recall 0.0000, precision nan
Right Kidney: dice 0.0000, recall 0.0000, precision nan
Left Kidney: dice 0.0000, recall 0.0000, precision nan
Esophagus: dice 0.0000, recall 0.0000, precision nan
Liver: dice 0.0000, recall 0.0000, precision nan
Stomach: dice 0.0000, recall 0.0000, precision nan
Aorta: dice 0.0000, recall 0.0000, precision nan
Postcava: dice 0.0000, recall 0.0000, precision nan
Portal Vein and Splenic Vein: dice 0.0000, recall 0.0000, precision nan
Pancreas: dice 0.0000, recall 0.0000, precision nan
Right Adrenal Gland: dice 0.0000, recall 0.0000, precision nan
Left Adrenal Gland: dice 0.0000, recall 0.0000, precision nan

case01_Multi-Atlas_Labeling/label/label0035| Spleen: 0.0000, Right Kidney: 0.0000, Left Kidney: 0.0000, Esophagus: 0.0000, Liver: 0.0000, Stomach: 0.0000, Aorta: 0.0000, Postcava: 0.0000, Portal Vein and Splenic Vein: 0.0000, Pancreas: 0.0000, Right Adrenal Gland: 0.0000, Left Adrenal Gland: 0.0000

Have you ever encountered a similar problem? I hope to receive your reply, thank you!

Here is the training code:

```python
import torch
from tqdm import tqdm


def train(args, train_loader, model, optimizer, loss_seg_DICE, loss_seg_CE):
    model.train()
    loss_bce_ave = 0
    loss_dice_ave = 0
    epoch_iterator = tqdm(
        train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True
    )
    for step, batch in enumerate(epoch_iterator):
        x, y, name = batch["image"].to(args.device), batch["post_label"].float().to(args.device), batch['name']
        torch.cuda.empty_cache()
        with torch.cuda.amp.autocast():
            logit_map = model(x)
        torch.cuda.empty_cache()

        term_seg_Dice = loss_seg_DICE.forward(logit_map, y, name, TEMPLATE)
        term_seg_BCE = loss_seg_CE.forward(logit_map, y, name, TEMPLATE)
        loss = term_seg_BCE + term_seg_Dice
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epoch_iterator.set_description(
            "Epoch=%d: Training (%d / %d Steps) (dice_loss=%2.5f, bce_loss=%2.5f)" % (
                args.epoch, step, len(train_loader), term_seg_Dice.item(), term_seg_BCE.item())
        )
        loss_bce_ave += term_seg_BCE.item()
        loss_dice_ave += term_seg_Dice.item()
        torch.cuda.empty_cache()
    print('Epoch=%d: ave_dice_loss=%2.5f, ave_bce_loss=%2.5f' % (args.epoch, loss_dice_ave/len(epoch_iterator), loss_bce_ave/len(epoch_iterator)))

    return loss_dice_ave/len(epoch_iterator), loss_bce_ave/len(epoch_iterator)
```

sharonlee12 commented 6 months ago

Here are the training loss curves: [image] [image]

zjy399 commented 6 months ago

Have you solved it?

sharonlee12 commented 6 months ago

> Have you solved it?

Not yet.

sharonlee12 commented 6 months ago

> Have you solved it?

If you solve it, could you let me know? I will do the same.

ljwztc commented 6 months ago

I'm not sure whether `with torch.cuda.amp.autocast():` and `torch.cuda.empty_cache()` in your code would affect the gradient backpropagation.
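For reference, the usual mixed-precision pattern in PyTorch pairs `autocast` with a `GradScaler`, so that small float16 gradients are scaled before the backward pass. The sketch below is only an illustration of that general pattern, not this repository's training loop: `amp_train_step`, the tiny `Conv3d` model, and the random tensors are hypothetical stand-ins, and AMP is switched off automatically when no GPU is present.

```python
import torch
import torch.nn as nn


def amp_train_step(model, optimizer, scaler, x, y, loss_fn, use_amp):
    """One optimizer step with autocast + GradScaler (hypothetical helper)."""
    optimizer.zero_grad()
    # Forward pass and loss are both computed inside the autocast region.
    with torch.cuda.amp.autocast(enabled=use_amp):
        logits = model(x)
        loss = loss_fn(logits, y)
    # Scale the loss before backward; when the scaler is disabled this is a no-op.
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscales gradients, skips the step on inf/nan
    scaler.update()
    return loss.item()


# Toy setup (assumption: stand-in for the real model and data).
device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"
model = nn.Conv3d(1, 2, kernel_size=1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
loss_fn = nn.BCEWithLogitsLoss()

x = torch.randn(2, 1, 4, 4, 4, device=device)
y = torch.rand(2, 2, 4, 4, 4, device=device).round()

before = [p.detach().clone() for p in model.parameters()]
loss_value = amp_train_step(model, optimizer, scaler, x, y, loss_fn, use_amp)
params_changed = any(
    not torch.equal(b, p.detach()) for b, p in zip(before, model.parameters())
)
```

Calling `torch.cuda.empty_cache()` only releases cached allocator memory and should not change gradients, although calling it every step can slow training. Whether either call is related to the behaviour reported in this issue is not established here.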

sharonlee12 commented 6 months ago

> I'm not sure whether `with torch.cuda.amp.autocast():` and `torch.cuda.empty_cache()` in your code would affect the gradient backpropagation.

May I ask if you did any additional data preprocessing? I only ran label_transfer.py.

sharonlee12 commented 6 months ago

> I'm not sure whether `with torch.cuda.amp.autocast():` and `torch.cuda.empty_cache()` in your code would affect the gradient backpropagation.

Hello! I have removed `with torch.cuda.amp.autocast():` and `torch.cuda.empty_cache()`. Now both the Dice loss and the BCE loss decrease during training, but at test time the results are still all 0:

Liver: dice 0.0000, recall 0.0000, precision 0.0000
Liver Tumor: dice 0.0000, recall 0.0000, precision nan
case04_LiTS/label/liver_1| Liver: 0.0000, Liver Tumor: 0.0000

Early in training (epoch 50), the result was:

Liver: dice 0.1749, recall 0.4666, precision 0.1077
Liver Tumor: dice 0.0032, recall 1.0000, precision 0.0016

Adoreeeeee commented 5 months ago

I have the same problem. I hope you can solve it.

ljwztc commented 1 month ago

The bug in the dice loss calculation has been addressed at this link. Consequently, we can now observe an expected decrease in the dice loss.
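For anyone who wants to sanity-check a Dice loss locally, here is a minimal sketch. It is a generic soft Dice loss, not the repository's implementation and not a description of the fix; `soft_dice_loss`, the smoothing term `eps`, and the toy tensors are assumptions for illustration. The check is that the loss is near 0 for a correct prediction, near 1 for an inverted one, and produces a non-zero gradient.

```python
import torch


def soft_dice_loss(logits, target, eps=1.0):
    """Generic soft Dice loss on sigmoid probabilities (illustrative only)."""
    probs = torch.sigmoid(logits)
    dims = tuple(range(1, logits.dim()))  # reduce over all but the batch dim
    intersection = (probs * target).sum(dim=dims)
    denom = probs.sum(dim=dims) + target.sum(dim=dims)
    dice = (2 * intersection + eps) / (denom + eps)
    return 1 - dice.mean()


# Toy label: half of a 4x4x4 volume is foreground.
target = torch.zeros(1, 1, 4, 4, 4)
target[:, :, :2] = 1.0

good_logits = (target * 2 - 1) * 10  # confident and correct
bad_logits = -good_logits            # confident and inverted

loss_good = soft_dice_loss(good_logits, target).item()
loss_bad = soft_dice_loss(bad_logits, target).item()

# Gradient check from an uninformative starting point (all logits zero).
logits = torch.zeros_like(target, requires_grad=True)
soft_dice_loss(logits, target).backward()
grad_norm = logits.grad.abs().sum().item()
```

If a loss fails a check like this, for example by returning the same value for `good_logits` and `bad_logits`, the loss computation itself is worth inspecting before the data pipeline.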