uni-medical / SAM-Med3D

SAM-Med3D: An Efficient General-purpose Promptable Segmentation Model for 3D Volumetric Medical Image
Apache License 2.0

fix up always dice=0 and find a new bug #44

Closed · adrianzzk closed this issue 8 months ago

adrianzzk commented 8 months ago

I had been troubled by the dice=0 bug for a long time, so I finally sat down and fixed it. While fixing it, I also found another bug: when printing the average loss for each epoch, the step count used as the divisor is one less than the actual number of steps, so the printed per-epoch loss is biased high. I'm not sure whether what I found is a real bug, so I'd appreciate it if the author team could take a look. I'm not very familiar with git yet, so I've pasted the modified code below.

    def train_epoch(self, epoch, num_clicks):
        epoch_loss = 0
        epoch_iou = 0
        epoch_dice = 0
        self.model.train()
        if self.args.multi_gpu:
            sam_model = self.model.module
        else:
            sam_model = self.model
            self.args.rank = -1

        if not self.args.multi_gpu or (self.args.multi_gpu and self.args.rank == 0):
            tbar = tqdm(self.dataloaders)
        else:
            tbar = self.dataloaders

        self.optimizer.zero_grad()
        step_loss = 0
        for step, (image3D, gt3D) in enumerate(tbar):

            my_context = self.model.no_sync if self.args.rank != -1 and step % self.args.accumulation_steps != 0 else nullcontext

            with my_context():

                image3D = self.norm_transform(image3D.squeeze(dim=1)) # (N, C, W, H, D)
                image3D = image3D.unsqueeze(dim=1)

                image3D = image3D.to(device)
                gt3D = gt3D.to(device).type(torch.long)
                with amp.autocast():
                    image_embedding = sam_model.image_encoder(image3D)

                    self.click_points = []
                    self.click_labels = []

                    pred_list = []

                    prev_masks, loss = self.interaction(sam_model, image_embedding, gt3D, num_clicks=11)                

                epoch_loss += loss.item()

                epoch_dice += self.get_dice_score(prev_masks,gt3D)      #change

                cur_loss = loss.item()

                loss /= self.args.accumulation_steps

                self.scaler.scale(loss).backward()    

            if step % self.args.accumulation_steps == 0 and step != 0:
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                print_loss = step_loss / self.args.accumulation_steps
                step_loss = 0
                print_dice = self.get_dice_score(prev_masks, gt3D)

            else:
                step_loss += cur_loss

            if not self.args.multi_gpu or (self.args.multi_gpu and self.args.rank == 0):
                if step % self.args.accumulation_steps == 0 and step != 0:
                    print(f'Epoch: {epoch}, Step: {step}, Loss: {print_loss}, Dice: {print_dice}')
                    if print_dice > self.step_best_dice:
                        self.step_best_dice = print_dice
                        if print_dice > 0.9:
                            self.save_checkpoint(
                                epoch,
                                sam_model.state_dict(),
                                describe=f'{epoch}_step_dice:{print_dice}_best'
                            )
                    if print_loss < self.step_best_loss:
                        self.step_best_loss = print_loss

        epoch_loss /= step+1        #change
        epoch_dice /= step+1        #change

        return epoch_loss, epoch_iou, epoch_dice, pred_list
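The off-by-one fix marked `#change` above comes down to how `enumerate` counts: it yields indices `0 .. N-1`, so after the loop `step` equals `N-1` and the correct divisor for an average is `step + 1`. A minimal standalone sketch (the loss values are hypothetical, just to make the bias visible):

```python
# enumerate() yields indices 0..N-1, so after the loop `step` == N-1.
# Dividing a summed loss by `step` instead of `step + 1` therefore
# inflates the reported average, which is the bug described above.
losses = [0.8, 0.6, 0.4, 0.2]  # hypothetical per-step losses, N = 4

total = 0.0
for step, loss in enumerate(losses):
    total += loss

buggy_avg = total / step        # divides by 3 -> biased high (~0.667)
fixed_avg = total / (step + 1)  # divides by 4 -> correct (~0.5)
```

The smaller `N` is, the larger the relative bias, which is why short epochs made the printed loss noticeably too high.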
adrianzzk commented 8 months ago

This is a modification to the `train_epoch` function in `train.py`.

sdzcgaoxiang commented 8 months ago

Thanks, but doesn't this increase the amount of computation? An alternative would be to split monai's DiceCELoss into its Dice + BCE components for each epoch and reuse the Dice term inside the loss as the train dice. But that seems to require changes in quite a few places, so I'll look forward to the authors' next release.
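The idea above relies on the identity between the Dice score and the Dice loss: since `dice_loss = 1 - dice_score`, a loss that already computes a Dice term effectively gives you the train dice for free. A minimal pure-Python sketch on flat binary masks (the function name and inputs are hypothetical, not from the repo's code):

```python
def dice_score(pred, gt, eps=1e-8):
    """Dice coefficient 2*|P∩G| / (|P| + |G|) for flat binary masks.
    eps guards against division by zero when both masks are empty."""
    inter = sum(p * g for p, g in zip(pred, gt))
    total = sum(pred) + sum(gt)
    return (2.0 * inter + eps) / (total + eps)

pred = [1, 1, 0, 0]
gt   = [1, 0, 0, 0]
score = dice_score(pred, gt)   # intersection 1, sizes 2 + 1 -> ~2/3
dice_loss = 1.0 - score        # the Dice term a DiceCE-style loss minimizes
```

So logging `1 - dice_loss_term` during training would report the same dice that `get_dice_score` recomputes from `prev_masks`, without the extra forward pass over the predictions.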

adrianzzk commented 8 months ago

@sdzcgaoxiang I only wanted to get the dice to display at all; I hadn't considered the computational cost...

blueyo0 commented 8 months ago

Thanks, this bug report is accepted, and the related fix has been applied in new commits.