bowang-lab / MedSAM

Segment Anything in Medical Images
https://www.nature.com/articles/s41467-024-44824-z
Apache License 2.0
3.03k stars 419 forks

Abnormal test loss: How to solve the problem that the validation loss does not decrease? #337

Open lxxq412 opened 1 week ago

lxxq412 commented 1 week ago

We built our own image segmentation dataset, which contains more than 900 images. We added validation code as follows:

    for epoch in range(start_epoch, num_epochs):
        # Training mode
        medsam_model.train()
        train_epoch_loss = 0
        for step, (image, gt2D, boxes, _) in enumerate(tqdm(train_dataloader)):
            optimizer.zero_grad() 
            boxes_np = boxes.detach().cpu().numpy()
            image, gt2D = image.to(device), gt2D.to(device)

            if args.use_amp:
                ## AMP: disabled by default
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    medsam_pred = medsam_model(image, boxes_np)
                    loss = seg_loss(medsam_pred, gt2D) + ce_loss(
                        medsam_pred, gt2D.float()
                    )
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            else:
                medsam_pred = medsam_model(image, boxes_np)
                loss = seg_loss(medsam_pred, gt2D) + ce_loss(medsam_pred, gt2D.float())
                loss.backward()
                optimizer.step() 
                optimizer.zero_grad()

            train_epoch_loss += loss.item()
            iter_num += 1

        train_epoch_loss /= (step + 1)  # mean training loss for this epoch (step is 0-based)
        train_losses.append(train_epoch_loss)

        # Test mode
        medsam_model.eval()  # put the model in evaluation mode
        test_epoch_loss = 0
        with torch.no_grad():
            for step, (image, gt2D, boxes, _) in enumerate(test_dataloader):
                image, gt2D = image.to(device), gt2D.to(device)
                boxes_np = boxes.detach().cpu().numpy()
                medsam_pred = medsam_model(image, boxes_np)
                loss = seg_loss(medsam_pred, gt2D) + ce_loss(medsam_pred, gt2D.float())
                test_epoch_loss += loss.item()
        test_epoch_loss /= (step + 1)  # mean test loss for this epoch (step is 0-based)
        test_losses.append(test_epoch_loss)

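However, the loss plot from our training looks wrong: the test loss does not decrease, which is very strange. Can anyone point out what the problem might be?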

[Image 1: loss plot]

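We lowered the learning rate to 1e-7, and the loss curve became overly smooth. Is that also a problem?

[Image 2: loss plot with the learning rate at 1e-7]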

Could anyone advise on how to tune the parameters to optimize the loss?

JunMa11 commented 1 week ago

How about using the Dice score to monitor the validation performance?
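
A minimal sketch of what such a Dice monitor could look like, not code from the MedSAM repository. It assumes the model output is a batch of raw logits of shape (B, 1, H, W), that the ground truth is a binary mask of the same shape, and that both have been moved to NumPy arrays. The function name `dice_score`, the 0.5 threshold, and the `eps` smoothing term are illustrative choices.

```python
import numpy as np


def dice_score(pred_logits, gt_mask, threshold=0.5, eps=1e-6):
    """Mean Dice coefficient over a batch of binary masks.

    pred_logits: array of shape (B, 1, H, W) holding raw logits (assumed).
    gt_mask:     array of the same shape; non-zero values count as foreground.
    """
    logits = np.asarray(pred_logits, dtype=np.float64)
    probs = 1.0 / (1.0 + np.exp(-logits))           # sigmoid
    pred = (probs > threshold).astype(np.float64)   # hard binary prediction
    gt = (np.asarray(gt_mask) > 0).astype(np.float64)

    # Flatten each sample so the sums below are computed per image.
    pred = pred.reshape(pred.shape[0], -1)
    gt = gt.reshape(gt.shape[0], -1)

    intersection = (pred * gt).sum(axis=1)
    denom = pred.sum(axis=1) + gt.sum(axis=1)
    dice = (2.0 * intersection + eps) / (denom + eps)  # eps avoids 0/0 on empty masks
    return float(dice.mean())
```

Inside the `torch.no_grad()` validation loop from the question, it could be called as `dice_score(medsam_pred.detach().cpu().numpy(), gt2D.cpu().numpy())` and averaged over batches in the same way as the loss. Because it is computed on thresholded masks, it measures overlap directly rather than how confident the logits are.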