I want to run the demo. #103

Open HansolEom opened 4 years ago

HansolEom commented 4 years ago

I want to visualize how this model segments in a single image. Is there a way?

sunke123 commented 4 years ago

When testing, you can set "sv_pred=True".

a79687417 commented 4 years ago

Is there a way to test your model on my own single image?

Pro-and-Khan commented 4 years ago

I am trying to use the pre-trained model 'hrnet_ocr_cs_trainval_8227_torch11.pth', but I can't find the appropriate cfg argument for the HRNet model:

ZhuangPeng97 commented 4 years ago

Hi, have you implemented it? I want to test on my own single image.

dreamlychina commented 3 years ago

import argparse

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from PIL import Image
from torch.nn import functional as F

import lib.models.seg_hrnet as seg_models
from lib.config import config
from lib.config import update_config_demo

# Per-channel normalisation statistics, applied in RGB order after pixels are
# scaled to [0, 1] (see FaceSeg._preprocess).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]


class FaceSeg:
    """Single-image inference wrapper around an HRNet segmentation model.

    Builds the network described by a yacs experiment config, loads trained
    weights and exposes :meth:`run` to segment one image read with
    ``cv2.imread``.
    """

    def __init__(self,
                 cfg_file='./experiments/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484.yaml',
                 weights='',
                 device='cpu',
                 imgsz=[700, 700],
                 num_classes=4,  # include background
                 ignore_label=255):
        """Load config and weights and prepare the model for inference.

        Args:
            cfg_file: experiment YAML merged into the global ``config``;
                falsy to use ``config`` as already set up by the caller.
            weights: checkpoint path; falls back to ``config.TEST.MODEL_FILE``.
            device: ``'cpu'``, or anything else to run on CUDA via
                ``nn.DataParallel`` over ``config.GPUS``.
            imgsz: network input size as ``(width, height)`` for ``cv2.resize``.
            num_classes: number of classes including background. NOTE(review):
                the network's output channels come from
                ``config.DATASET.NUM_CLASSES``; this value is only stored.
            ignore_label: value used for "ignored" ids in ``label_mapping``.
                NOTE(review): should match the ignore label used in training.

        Raises:
            ValueError: if neither ``weights`` nor ``config.TEST.MODEL_FILE``
                names a checkpoint.
        """
        self.device = device
        self.crop_size = list(imgsz)  # copy: never alias the mutable default
        self.num_classes = num_classes

        # Merge the experiment config first: everything below reads from it.
        if cfg_file:
            config.defrost()
            config.merge_from_file(cfg_file)
            config.freeze()

        # cudnn related setting
        cudnn.benchmark = config.CUDNN.BENCHMARK
        cudnn.deterministic = config.CUDNN.DETERMINISTIC
        cudnn.enabled = config.CUDNN.ENABLED

        model = self._build_model()
        self._load_weights(model, weights or config.TEST.MODEL_FILE)
        model.eval()  # inference mode for BatchNorm/Dropout layers
        if device != 'cpu':
            gpus = list(config.GPUS)
            model = nn.DataParallel(model, device_ids=gpus).cuda()
            print("use gpu seg")
        else:
            print("use cpu seg")
        self.model = model

        self.label_mapping = {-1: ignore_label, 0: ignore_label,
                              1: ignore_label, 2: ignore_label,
                              3: ignore_label, 4: ignore_label,
                              5: ignore_label, 6: ignore_label,
                              7: 0, 8: 1, 9: ignore_label,
                              10: ignore_label, 11: 2, 12: 3,
                              13: 4, 14: ignore_label, 15: ignore_label,
                              16: ignore_label, 17: 5, 18: ignore_label,
                              19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11,
                              25: 12, 26: 13, 27: 14, 28: 15,
                              29: ignore_label, 30: ignore_label,
                              31: 16, 32: 17, 33: 18}

    @staticmethod
    def _build_model():
        """Instantiate the segmentation network described by ``config``."""
        module = seg_models
        # Swap the module's BatchNorm classes for plain nn.BatchNorm2d for
        # single-process inference. The original did this only when the torch
        # version started with '1', leaving `module` unbound (NameError) on
        # torch 2.x; only a 0.x torch now skips the swap.
        if not torch.__version__.startswith('0'):
            module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
        return module.get_seg_model(config)

    @staticmethod
    def _load_weights(model, model_state_file):
        """Copy matching parameters from a checkpoint into ``model``.

        Accepts either a bare state dict or a dict holding one under
        ``'state_dict'``. A leading ``'model.'`` is stripped from keys when
        present (the original chopped the first 6 characters of every key,
        presumably for that prefix); keys absent from ``model`` are ignored.

        Raises:
            ValueError: if ``model_state_file`` is empty.
            RuntimeError: if no checkpoint key matches a model parameter.
        """
        if not model_state_file:
            raise ValueError(
                "cant find model_file: pass `weights` or set TEST.MODEL_FILE")
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
        pretrained_dict = torch.load(model_state_file, map_location='cpu')
        if 'state_dict' in pretrained_dict:
            pretrained_dict = pretrained_dict['state_dict']
        model_dict = model.state_dict()
        prefix = 'model.'
        matched = {}
        for key, value in pretrained_dict.items():
            name = key[len(prefix):] if key.startswith(prefix) else key
            if name in model_dict:
                matched[name] = value
        if not matched:
            raise RuntimeError(
                f"no parameters in {model_state_file} match the model")
        model_dict.update(matched)
        model.load_state_dict(model_dict)

    def _preprocess(self, img0):
        """Turn an H x W x 3 image into a normalised (1, 3, h, w) float tensor.

        Assumes 8-bit pixel values (divided by 255) in cv2's BGR order.
        """
        image = img0.astype(np.float32)[:, :, ::-1]  # BGR -> RGB
        image = image / 255.0
        image -= mean
        image /= std
        # cv2.resize takes dsize as (width, height).
        image = cv2.resize(image, (self.crop_size[0], self.crop_size[1]),
                           interpolation=cv2.INTER_LINEAR)
        image = image.transpose((2, 0, 1))[np.newaxis]  # HWC -> 1CHW
        return torch.from_numpy(np.ascontiguousarray(image))

    @torch.no_grad()
    def run(self, img0, visual=True):
        """Segment one image.

        Args:
            img0: image array as returned by ``cv2.imread`` (H x W x 3).
            visual: when True, also build a palettised PIL image of the
                predicted label ids.

        Returns:
            ``(pred, save_img)`` where ``pred`` is an (H, W) uint8 array of
            predicted class indices at the input resolution and ``save_img``
            is the palettised ``PIL.Image`` (``None`` when ``visual`` is False).
        """
        ori_height, ori_width = img0.shape[:2]
        new_img = self._preprocess(img0)
        if self.device != 'cpu':
            new_img = new_img.cuda()

        preds = self.model(new_img)
        if isinstance(preds, (list, tuple)):
            # NOTE(review): for multi-output model variants the last output is
            # used; confirm this matches your model.
            preds = preds[-1]
        # Upsample the logits straight back to the original image size.
        preds = F.interpolate(
            preds, size=(ori_height, ori_width),
            mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS)
        pred = preds.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)

        save_img = None
        if visual:
            palette = self.get_palette(256)
            save_img = Image.fromarray(
                self.convert_label(pred.copy(), inverse=True))
            save_img.putpalette(palette)
        return pred, save_img

    def get_palette(self, n):
        """Return a flat ``[r0, g0, b0, r1, g1, b1, ...]`` palette for ``n`` ids.

        The bits of each id are spread across the three channels, filling
        each channel from its most significant bit downwards.
        """
        palette = [0] * (n * 3)
        for j in range(n):
            lab = j
            i = 0
            while lab:
                palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
                palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
                palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
                i += 1
                lab >>= 3
        return palette

    def convert_label(self, label, inverse=False):
        """Remap ``label`` in place through ``self.label_mapping`` and return it.

        Forward maps keys to values. ``inverse=True`` maps values back to
        keys; where several keys share a value, the last one wins.
        """
        temp = label.copy()  # masks come from the untouched copy
        if inverse:
            for v, k in self.label_mapping.items():
                # Negative ids cannot be stored in an unsigned map (NumPy 2
                # raises OverflowError). In the mapping above they only map to
                # ignore_label, which argmax normally never produces.
                if v < 0:
                    continue
                label[temp == k] = v
        else:
            for k, v in self.label_mapping.items():
                label[temp == k] = v
        return label


if __name__ == "__main__":
    face_segt = FaceSeg(weights="your_path/cityscapes/seg_hrnet_w48_train_512x1024_sgd_lr1e-2_wd5e-4_bs_12_epoch484/300_checkpoint.pth.tar")
    img = cv2.imread("yours.png")
    if img is None:  # cv2.imread returns None instead of raising
        raise SystemExit("could not read yours.png")
    pred, save_img = face_segt.run(img)
    save_img.save("yours_seg.png")