IcecreamArtist opened 10 months ago
Hi,
In the code:
loss1 = 0.5 * (ce_loss(outputs1[:args.labeled_bs], label_batch[:][:args.labeled_bs].long()) + dice_loss( outputs_soft1[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1)))
loss2 = 0.5 * (ce_loss(outputs2[:args.labeled_bs], label_batch[:][:args.labeled_bs].long()) + dice_loss( outputs_soft2[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1)))
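For ce_loss, the code indexes with 'label_batch[:][:args.labeled_bs]', which can be simplified to 'label_batch[:args.labeled_bs]', as the dice_loss term already does. The leading '[:]' just selects the whole batch, so the second index slices the first (batch) dimension exactly as a single slice would.
Discussion is welcome if I have made a mistake.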
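A quick check of the equivalence, as a minimal sketch. It uses NumPy as a stand-in for PyTorch (basic slicing behaves the same way in both for this case), and the batch shape (4, 8, 8) and labeled_bs = 2 are hypothetical values chosen for illustration, not taken from the training script:

```python
import numpy as np

# Hypothetical stand-ins: a batch of 4 integer label maps of size 8x8,
# where the first `labeled_bs` samples are the labeled ones.
labeled_bs = 2
rng = np.random.default_rng(0)
label_batch = rng.integers(0, 4, size=(4, 8, 8))

# Original indexing: `[:]` selects every element along dim 0 (a view of the
# whole array), so the second index slices dim 0 again.
original = label_batch[:][:labeled_bs]

# Simplified indexing proposed above.
simplified = label_batch[:labeled_bs]

print(original.shape, simplified.shape)      # (2, 8, 8) (2, 8, 8)
print(np.array_equal(original, simplified))  # True

# Note: `[:][:n]` is NOT the same as `[:, :n]`, which slices dim 1 instead.
print(label_batch[:, :labeled_bs].shape)     # (4, 2, 8)
```

So the extra '[:]' changes nothing about which labels reach ce_loss; removing it is purely a readability fix.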