jeya-maria-jose / UNeXt-pytorch

Official Pytorch Code base for "UNeXt: MLP-based Rapid Medical Image Segmentation Network", MICCAI 2022
https://jeya-maria-jose.github.io/UNext-web/
MIT License
459 stars, 76 forks

sigmoid(): argument 'input' (position 1) must be Tensor, not numpy.ndarray #15

liuyx599 closed this issue 1 year ago

liuyx599 commented 2 years ago

In post_process.py

with torch.no_grad():
    for input, target, meta in tqdm(val_loader, total=len(val_loader)):
        input = input.cuda()
        target = target.cuda()
        model = model.cuda()
        # compute output

        if count <= 5:
            start = time.time()
            if config['deep_supervision']:
                output = model(input)[-1]
            else:
                output = model(input)
            stop = time.time()

            gput.update(stop - start, input.size(0))

            start = time.time()
            model = model.cpu()
            input = input.cpu()
            output = model(input)
            stop = time.time()

            cput.update(stop - start, input.size(0))
            count = count + 1

        iou, dice = iou_score(output, target)
        iou_avg_meter.update(iou, input.size(0))
        dice_avg_meter.update(dice, input.size(0))

        output = torch.sigmoid(output).cpu().numpy()   # error

I finished training on BUSI without any errors, but when I then ran validation with _postprocess.py, an error was raised once count exceeded 5. As far as I can tell, output is only assigned inside the if count <= 5 branch, so once count exceeds 5 the current input is never fed to the model. On those iterations output still holds the NumPy array produced by the last line of the previous iteration, so output = torch.sigmoid(output).cpu().numpy() passes an ndarray to torch.sigmoid, which raises the error in the title. Is that the cause?
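A minimal sketch of one way to restructure the loop, assuming the goal is to evaluate every batch while still timing only the first few. This is not the repository's code: `evaluate`, `AverageMeter`, `_sync`, `timed_batches` and `cpu_model` are hypothetical names, the `iou_score` metric is left out, and a deep-supervision model is assumed to return a list whose last element is the final prediction (mirroring `model(input)[-1]` above). The key changes are that the forward pass runs on every batch, outside the timing branch, and that the NumPy probabilities are stored under a new name, so `output` is never an ndarray when `torch.sigmoid` sees it.

```python
import copy
import time

import torch


class AverageMeter:
    """Minimal running-average meter (hypothetical stand-in for the script's meters)."""

    def __init__(self):
        self.sum, self.count = 0.0, 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)


def _sync(device):
    # CUDA kernels run asynchronously; wait for them so wall-clock timing is meaningful.
    if device.type == "cuda":
        torch.cuda.synchronize()


def evaluate(model, loader, deep_supervision=False, timed_batches=6):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()
    cpu_model = copy.deepcopy(model).cpu()  # separate copy used only for CPU timing
    gput, cput = AverageMeter(), AverageMeter()
    probs = []

    with torch.no_grad():
        for count, (input, target, meta) in enumerate(loader):
            input = input.to(device)
            target = target.to(device)

            # Forward pass on EVERY batch, so `output` always belongs to this batch.
            _sync(device)
            start = time.time()
            output = model(input)
            if deep_supervision:
                output = output[-1]
            _sync(device)

            if count < timed_batches:
                gput.update(time.time() - start, input.size(0))
                # Time a CPU forward pass without reassigning `model` or `output`.
                start = time.time()
                cpu_model(input.cpu())
                cput.update(time.time() - start, input.size(0))

            # The original script computes iou_score(output, target) at this point.
            # Keep the NumPy result under a new name instead of overwriting `output`.
            prob = torch.sigmoid(output).cpu().numpy()
            probs.append(prob)

    return probs, gput, cput
```

A tiny `torch.nn.Conv2d` and a list of random `(input, target, meta)` tuples are enough to check that all batches are processed, including those past the timed ones. Keeping a separate CPU copy of the model avoids moving the GPU model back and forth on every timed batch, which in the original loop also left `model` and `input` on the CPU until the next iteration moved them back.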

Zzs0720 commented 1 year ago

Hello, I have also run into this problem. Did you manage to solve it?

demondemoliu commented 1 year ago

Hello, I have run into the same problem too. Could you please fix the code?

jeya-maria-jose commented 1 year ago

I did not use postprocess.py; it comes from the UNet++ codebase that this code is built on.