LiheYoung / UniMatch

[CVPR 2023] Revisiting Weak-to-Strong Consistency in Semi-Supervised Semantic Segmentation
https://arxiv.org/abs/2208.09910
MIT License
456 stars 59 forks

Hi, I'd like to ask whether you could open-source the inference code, i.e., the code for visualizing predictions? #18

Closed yang654123 closed 1 year ago

yang654123 commented 1 year ago

Thank you.

LiheYoung commented 1 year ago

Hi, you only need to add a few lines of visualization code to our evaluate function. For example, after obtaining pred, you can use the torchshow tool to visualize it.

yang654123 commented 1 year ago

Hi, could you help me check what is wrong with my code? I first added my trained resnet101.pth (accuracy around 75) to the resnet file. Then I rewrote the visualization code following your suggestion, but the problem below appeared. Could you help me solve it? Thank you.

Code:

parser = argparse.ArgumentParser(description='Semi-Supervised Semantic Segmentation')
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--labeled-id-path', type=str, required=True)
parser.add_argument('--unlabeled-id-path', type=str, default=None)
parser.add_argument('--save-path', type=str, required=True)
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--port', default=None, type=int)

save_folder = 'viewer'
colorfold = os.path.join(save_folder, "color")
gray_folder = os.path.join(save_folder, "gray")

def colorful(mask, colormap):
    color_mask = np.zeros([mask.shape[0], mask.shape[1], 3])
    for i in np.unique(mask):
        color_mask[mask == i] = colormap[i]
    return np.uint8(color_mask)

def create_pascal_label_colormap():
    """Creates a label colormap used in the Pascal segmentation benchmark.
    Returns: a colormap for visualizing segmentation results.
    """
    colormap = 255 * np.ones((256, 3), dtype=np.uint8)
    colormap[0] = [0, 0, 0]
    colormap[1] = [128, 0, 0]
    colormap[2] = [0, 128, 0]
    colormap[3] = [128, 128, 0]
    colormap[4] = [0, 0, 128]
    colormap[5] = [128, 0, 128]
    colormap[6] = [0, 128, 128]
    colormap[7] = [128, 128, 128]
    colormap[8] = [64, 0, 0]
    colormap[9] = [192, 0, 0]
    colormap[10] = [64, 128, 0]
    colormap[11] = [192, 128, 0]
    colormap[12] = [64, 0, 128]
    colormap[13] = [192, 0, 128]
    colormap[14] = [64, 128, 128]
    colormap[15] = [192, 128, 128]
    colormap[16] = [0, 64, 0]
    colormap[17] = [128, 64, 0]
    colormap[18] = [0, 192, 0]
    colormap[19] = [128, 192, 0]
    colormap[20] = [0, 64, 128]
    return colormap

def color_map(mask, colormap):
    color_mask = np.zeros([mask.shape[0], mask.shape[1], 3])
    for i in np.unique(mask):
        color_mask[mask == i] = colormap[i]
    return np.uint8(color_mask)

@torch.no_grad()
def net_process(model, image):
    input = image.cuda()
    output = model(input)
    return output

mean, std = [123.675, 116.28, 103.53], [58.395, 57.12, 57.375]
input_scale = [321, 321]
data_list = 'data/'

def evaluate(model, loader, mode, cfg):
    model.eval()
    assert mode in ['original', 'center_crop', 'sliding_window']
    # intersection_meter = AverageMeter()
    # union_meter = AverageMeter()
    colormap = create_pascal_label_colormap()
    with torch.no_grad():
        for img, mask, id in loader:
            # image_name = image_path.split("/")[-1]
            print(id)
            # print(id[0])
            id = os.path.join(data_list, id[0])
            id = id.split(" ")[0]
            print(id)
            img = Image.open(id).convert("RGB")
            img = np.asarray(img).astype(np.float32)
            h, w, _ = img.shape
            img = (img - mean) / std
            img = torch.Tensor(img).permute(2, 0, 1)
            img = img.unsqueeze(dim=0)
            img = F.interpolate(img, input_scale, mode="bilinear", align_corners=True)

            output = net_process(model, img)
            output = F.interpolate(output, (h, w), mode="bilinear", align_corners=True)
            mask = torch.argmax(output, dim=1).squeeze().cpu().numpy()

            color_mask = Image.fromarray(colorful(mask, colormap))
            id = id.split("/")[-1]
            color_mask.save(os.path.join(colorfold, id))
            #
            # mask = Image.fromarray(mask)
            # mask.save(os.path.join(gray_folder, id))

def main():
    args = parser.parse_args()

    cfg = yaml.load(open(args.config, "r"), Loader=yaml.Loader)

    logger = init_log('global', logging.INFO)
    logger.propagate = 0

    cudnn.enabled = True
    cudnn.benchmark = True

    model = DeepLabV3Plus(cfg)
    model.cuda()

    valset = SemiDataset(cfg['dataset'], cfg['data_root'], 'val')
    valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=2,
                           drop_last=False)
    eval_mode = 'sliding_window'
    evaluate(model, valloader, eval_mode, cfg)

if __name__ == '__main__':
    main()

Output: [screenshot]
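A side note on the script above: it constructs DeepLabV3Plus and moves it to the GPU, but never loads the trained resnet101 checkpoint, so evaluate runs with untrained weights. Below is a minimal, hedged sketch of restoring the weights before calling evaluate; the 'model' key and the stripping of a possible 'module.' prefix (left over from DistributedDataParallel training) are assumptions about how the checkpoint was saved, and the path is a placeholder.

# Hedged sketch: load the trained weights before calling evaluate().
ckpt = torch.load('path/to/your_trained_checkpoint.pth', map_location='cpu')  # placeholder path
state_dict = ckpt['model'] if 'model' in ckpt else ckpt  # 'model' key is an assumption; fall back to a raw state_dict
state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
              for k, v in state_dict.items()}  # strip a possible DDP prefix
model.load_state_dict(state_dict)
model.cuda()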

LiheYoung commented 1 year ago

After installing the torchshow mentioned above, you can use torchshow.save(pred[0], './vis/img_id.png') to save the predicted mask.
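For reference, a minimal sketch of what those few lines could look like inside the evaluation loop is given below; the loop variables follow the evaluate function above, the ./vis output directory and the filename handling are assumptions, and torchshow needs to be installed (e.g. pip install torchshow).

import os
import torchshow

os.makedirs('./vis', exist_ok=True)
with torch.no_grad():
    for img, mask, id in loader:
        pred = model(img.cuda()).argmax(dim=1)  # predicted class indices, shape (1, H, W)
        # take the image filename from the id string, as in the script above, and save as .png
        name = os.path.splitext(os.path.basename(id[0].split(' ')[0]))[0] + '.png'
        torchshow.save(pred[0].cpu(), os.path.join('./vis', name))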

tanveer6715 commented 1 year ago

Would you provide the visualization code that you used to visualize the results? It would be very helpful. Thank you.