cvlab-kaist / CAT-Seg

Official Implementation of "CAT-Seg🐱: Cost Aggregation for Open-Vocabulary Semantic Segmentation"
https://ku-cvlab.github.io/CAT-Seg/
MIT License

Code issue #3

Closed. Hydragon516 closed this issue 1 year ago.

Hydragon516 commented 1 year ago

Thank you for the interesting research. However, there appear to be a few bugs in the code:

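  1. datasets/prepare_voc.py
# Raises AttributeError on NumPy >= 1.24, which removed the np.str alias (deprecated since NumPy 1.20).
val_list = [osp.join(voc_path, "SegmentationClassAug", f + ".png")
            for f in np.loadtxt(osp.join(voc_path, "ImageSets/Segmentation/val.txt"), dtype=np.str).tolist()]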

Revised version:

val_list = [osp.join(voc_path, "SegmentationClassAug", f + ".png")
            for f in np.loadtxt(osp.join(voc_path, "ImageSets/Segmentation/val.txt"), dtype=str).tolist()]
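For reference, here is a minimal sketch of the same lookup that should run on both old and new NumPy releases, since the builtin str has always been a valid dtype. The helper name build_val_list, the in-memory stand-in for val.txt, and the sample image IDs are illustrative assumptions, not part of prepare_voc.py; ndmin=1 is an extra safeguard that the original line does not use.

```python
import io
import os.path as osp

import numpy as np


def build_val_list(voc_path, split_file):
    """Hypothetical helper mirroring the val_list comprehension in prepare_voc.py."""
    # dtype=str (the builtin) works on every NumPy version, unlike np.str.
    # ndmin=1 keeps a one-line split file from collapsing to a 0-d array,
    # whose .tolist() would return a bare string instead of a list.
    ids = np.loadtxt(split_file, dtype=str, ndmin=1).tolist()
    return [osp.join(voc_path, "SegmentationClassAug", f + ".png") for f in ids]


# Usage with an in-memory stand-in for ImageSets/Segmentation/val.txt.
val_list = build_val_list("VOC2012", io.StringIO("2007_000033\n2007_000042\n"))
print(val_list)
```

np.loadtxt accepts any file-like object, so the same call works with the real path from the original snippet.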

  2. cat_seg/cat_seg_model.py
def inference_sliding_window(self, batched_inputs, kernel=384, overlap=0.333, out_res=[640, 640]):
        ...

        image = F.interpolate(images[0].unsqueeze(0), size=out_res, mode='bilinear', align_corners=False).squeeze()
        image = rearrange(unfold(image), "(C H W) L-> L C H W", C=3, H=kernel)
        global_image = F.interpolate(images[0].unsqueeze(0), size=(kernel, kernel), mode='bilinear', align_corners=False)
        image = torch.cat((image, global_image), dim=0)
        ...

        global_output = outputs[-1:]
        global_output = F.interpolate(global_output, size=out_res, mode='bilinear', align_corners=False,)
        outputs = outputs[:-1]
        outputs = fold(outputs.flatten(1).T) / fold(unfold(torch.ones([1] + out_res, device=self.device)))
        outputs = (outputs + global_output) / 2.
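The divisor fold(unfold(torch.ones(...))) in the original code is a per-pixel count of how many sliding windows cover each output pixel, so the division turns the summed window predictions into an average. Below is a pure-Python sketch of that count for the default arguments. It assumes (because that part of the method is elided above) that the stride is derived as int(kernel * (1 - overlap)) and that Unfold/Fold use no padding or dilation; the helper names are hypothetical.

```python
def window_starts(size, kernel, stride):
    # Offsets of the sliding windows along one axis, matching nn.Unfold with
    # padding=0 and dilation=1: floor((size - kernel) / stride) + 1 windows.
    n = (size - kernel) // stride + 1
    return [i * stride for i in range(n)]


def coverage_1d(size, kernel, stride):
    # Number of windows covering each pixel along one axis.
    counts = [0] * size
    for start in window_starts(size, kernel, stride):
        for p in range(start, start + kernel):
            counts[p] += 1
    return counts


kernel, overlap, out_res = 384, 0.333, [640, 640]
stride = int(kernel * (1 - overlap))  # assumed stride formula, gives 256
starts = window_starts(out_res[0], kernel, stride)
cov = coverage_1d(out_res[0], kernel, stride)

# The 2-D count map is the outer product of the per-axis counts, so pixels in
# the 128-pixel overlap band of both axes are covered by 2 * 2 = 4 windows.
print(stride, starts, len(starts) ** 2, max(cov) ** 2)  # 256 [0, 256] 4 4
```

Under these assumptions, leaving out the division makes the stitched local scores four times their average in the central overlap region and twice their average in the edge bands, which shifts their balance against global_output in the final (outputs + global_output) / 2.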

Revised version:

def inference_sliding_window(self, batched_inputs, kernel=384, overlap=0.333, out_res=[640, 640]):
        ...

        image = F.interpolate(images[0].unsqueeze(0), size=out_res, mode='bilinear', align_corners=False).squeeze()
        image = unfold(image).squeeze(0)
        image = rearrange(image, "(C H W) L-> L C H W", C=3, H=kernel)
        global_image = F.interpolate(images[0].unsqueeze(0), size=(kernel, kernel), mode='bilinear', align_corners=False)
        image = torch.cat((image, global_image), dim=0)
        ...

        global_output = outputs[-1:]
        global_output = F.interpolate(global_output, size=out_res, mode='bilinear', align_corners=False,)
        outputs = outputs[:-1]
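        # Note: unlike the original, this drops the division by the per-pixel
        # window count, so overlapping regions are summed rather than averaged.
        outputs = fold(outputs.flatten(1).T)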
        outputs = (outputs + global_output) / 2.
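A possible version-independent alternative is to keep every tensor batched (4-D images for nn.Unfold, 3-D column tensors for nn.Fold), the input shapes both modules accept on older and newer PyTorch releases, while keeping the overlap normalization. This is a minimal sketch under stated assumptions, not the repository's code: tile_and_stitch is a hypothetical helper, the raw tiles stand in for the per-tile model outputs, and a transpose plus reshape replaces the einops rearrange call.

```python
import torch
import torch.nn as nn


def tile_and_stitch(x, kernel, stride):
    """Cut x (1, C, H, W) into overlapping tiles, then average them back.

    Hypothetical helper: assumes the windows tile H and W exactly, i.e.
    (H - kernel) % stride == 0; otherwise uncovered pixels would give 0 / 0.
    """
    _, c, h, w = x.shape
    unfold = nn.Unfold(kernel_size=kernel, stride=stride)
    fold = nn.Fold(output_size=(h, w), kernel_size=kernel, stride=stride)

    cols = unfold(x)                                  # (1, C*k*k, L)
    # Same layout as rearrange(..., "(C H W) L -> L C H W"):
    tiles = cols[0].T.reshape(-1, c, kernel, kernel)  # (L, C, k, k)

    # In the real method, per-tile model outputs would replace `tiles` here.

    summed = fold(tiles.flatten(1).T.unsqueeze(0))    # (1, C, H, W), overlaps summed
    counts = fold(unfold(torch.ones(1, 1, h, w)))     # (1, 1, H, W), windows per pixel
    return summed / counts                            # per-pixel mean over windows


x = torch.rand(1, 3, 640, 640)
y = tile_and_stitch(x, kernel=384, stride=256)
print(y.shape, torch.allclose(x, y))  # torch.Size([1, 3, 640, 640]) True
```

Because the tiles are folded back unchanged, the round trip reproduces the input, which is a quick way to check that the unfold/fold layout and the normalization agree before plugging in real model outputs.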

These bugs might be due to differences in Python, NumPy, and PyTorch versions.

hsshin98 commented 1 year ago

Hi, and thanks for pointing out the issues in our code! These do seem like issues with PyTorch and Python versions, so sharing your versions would also help. For our project, we used torch 1.13 and Python 3.8.13, and we also had no problems with torch 2.0, so using one of these configurations may help for now.
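When reporting version-dependent bugs like these, a small script can collect the interpreter and package versions to paste into the issue. This is a minimal sketch using only the standard library (importlib.metadata needs Python 3.8+); the helper name is hypothetical, and the default package list is an assumption matching the libraries discussed in this thread, plus einops on the assumption that rearrange comes from it.

```python
import platform
from importlib import metadata


def report_versions(packages=("torch", "numpy", "einops")):
    """Collect interpreter and package versions for a bug report.

    Hypothetical helper; the default package list assumes `rearrange`
    comes from einops.
    """
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


for key, value in report_versions().items():
    print(f"{key}: {value}")
```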