Closed Sba-Stuff closed 4 years ago
I Modified Code A Bit, Works like a charm. Thanks. ` from future import print_function
import os, glob, time, argparse, pdb, cv2
import numpy as np from skimage.measure import label
import torch import torch.nn as nn from torch.autograd import Variable import torch.backends.cudnn as cudnn
from functions import * from networks import ResnetConditionHR
torch.set_num_threads(1)
print('CUDA Device: ' + os.environ["CUDA_VISIBLE_DEVICES"])
"""Parses arguments.""" parser = argparse.ArgumentParser(description='Background Matting.') parser.add_argument('-m', '--trained_model', type=str, default='real-fixed-cam',choices=['real-fixed-cam', 'real-hand-held', 'syn-comp-adobe'],help='Trained background matting model') parser.add_argument('-o', '--output_dir', type=str, required=True,help='Directory to save the output results. (required)') parser.add_argument('-i', '--input_dir', type=str, required=True,help='Directory to load input images. (required)') parser.add_argument('-tb', '--target_back', type=str,help='Directory to load the target background.') parser.add_argument('-b', '--back', type=str,default=None,help='Captured background image. (only use for inference on videos with fixed camera')
args=parser.parse_args()
model_main_dir='Models/' + args.trained_model + '/';
data_path=args.input_dir
if os.path.isdir(args.target_back): args.video=True print('Using video mode') else: args.video=False print('Using image mode')
back_img10=cv2.imread(args.target_back); back_img10=cv2.cvtColor(back_img10,cv2.COLOR_BGR2RGB);
#Green-screen background
back_img20=np.zeros(back_img10.shape); back_img20[...,0]=120; back_img20[...,1]=255; back_img20[...,2]=155;
fo=glob.glob(model_main_dir + 'netGepoch*') model_name1=fo[0] netM=ResnetConditionHR(input_nc=(3,3,1,4),output_nc=4,n_blocks1=7,n_blocks2=3) netM=nn.DataParallel(netM) netM.load_state_dict(torch.load(model_name1)) netM.cpu(); netM.eval() cudnn.benchmark=True reso=(512,512) #input reoslution to the network
if args.back is not None: bg_im0=cv2.imread(args.back); bg_im0=cv2.cvtColor(bg_im0,cv2.COLOR_BGR2RGB);
test_imgs = [f for f in os.listdir(data_path) if os.path.isfile(os.path.join(data_path, f)) and f.endswith('_img.png')] test_imgs.sort()
result_path=args.output_dir
if not os.path.exists(result_path): os.makedirs(result_path)
for i in range(0,len(test_imgs)): filename = test_imgs[i]
bgr_img = cv2.imread(os.path.join(data_path, filename)); bgr_img=cv2.cvtColor(bgr_img,cv2.COLOR_BGR2RGB);
if args.back is None:
#captured background image
bg_im0=cv2.imread(os.path.join(data_path, filename.replace('_img','_back'))); bg_im0=cv2.cvtColor(bg_im0,cv2.COLOR_BGR2RGB);
#segmentation mask
rcnn = cv2.imread(os.path.join(data_path, filename.replace('_img','_masksDL')),0);
if args.video: #if video mode, load target background frames
#target background path
back_img10=cv2.imread(os.path.join(args.target_back,filename.replace('_img.png','.png'))); back_img10=cv2.cvtColor(back_img10,cv2.COLOR_BGR2RGB);
#Green-screen background
back_img20=np.zeros(back_img10.shape); back_img20[...,0]=120; back_img20[...,1]=255; back_img20[...,2]=155;
#create multiple frames with adjoining frames
gap=20
multi_fr_w=np.zeros((bgr_img.shape[0],bgr_img.shape[1],4))
idx=[i-2*gap,i-gap,i+gap,i+2*gap]
for t in range(0,4):
if idx[t]<0:
idx[t]=len(test_imgs)+idx[t]
elif idx[t]>=len(test_imgs):
idx[t]=idx[t]-len(test_imgs)
file_tmp=test_imgs[idx[t]]
bgr_img_mul = cv2.imread(os.path.join(data_path, file_tmp));
multi_fr_w[...,t]=cv2.cvtColor(bgr_img_mul,cv2.COLOR_BGR2GRAY);
else:
## create the multi-frame
multi_fr_w=np.zeros((bgr_img.shape[0],bgr_img.shape[1],4))
multi_fr_w[...,0] = cv2.cvtColor(bgr_img,cv2.COLOR_BGR2GRAY);
multi_fr_w[...,1] = multi_fr_w[...,0]
multi_fr_w[...,2] = multi_fr_w[...,0]
multi_fr_w[...,3] = multi_fr_w[...,0]
#crop tightly
bgr_img0=bgr_img; print(str(bgr_img0.shape[0])+"\n"+str(bgr_img0.shape[1]))
bbox=get_bbox(rcnn,R=bgr_img0.shape[0],C=bgr_img0.shape[1])
crop_list=[bgr_img,bg_im0,rcnn,back_img10,back_img20,multi_fr_w]
crop_list=crop_images(crop_list,reso,bbox)
bgr_img=crop_list[0]; bg_im=crop_list[1]; rcnn=crop_list[2]; back_img1=crop_list[3]; back_img2=crop_list[4]; multi_fr=crop_list[5]
#process segmentation mask
kernel_er = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
kernel_dil = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
rcnn=rcnn.astype(np.float32)/255; rcnn[rcnn>0.2]=1;
K=25
zero_id=np.nonzero(np.sum(rcnn,axis=1)==0)
del_id=zero_id[0][zero_id[0]>250]
if len(del_id)>0:
del_id=[del_id[0]-2,del_id[0]-1,*del_id]
rcnn=np.delete(rcnn,del_id,0)
rcnn = cv2.copyMakeBorder( rcnn, 0, K + len(del_id), 0, 0, cv2.BORDER_REPLICATE)
rcnn = cv2.erode(rcnn, kernel_er, iterations=10)
rcnn = cv2.dilate(rcnn, kernel_dil, iterations=5)
rcnn=cv2.GaussianBlur(rcnn.astype(np.float32),(31,31),0)
rcnn=(255*rcnn).astype(np.uint8)
rcnn=np.delete(rcnn, range(reso[0],reso[0]+K), 0)
#convert to torch
img=torch.from_numpy(bgr_img.transpose((2, 0, 1))).unsqueeze(0); img=2*img.float().div(255)-1
bg=torch.from_numpy(bg_im.transpose((2, 0, 1))).unsqueeze(0); bg=2*bg.float().div(255)-1
rcnn_al=torch.from_numpy(rcnn).unsqueeze(0).unsqueeze(0); rcnn_al=2*rcnn_al.float().div(255)-1
multi_fr=torch.from_numpy(multi_fr.transpose((2, 0, 1))).unsqueeze(0); multi_fr=2*multi_fr.float().div(255)-1
with torch.no_grad():
img,bg,rcnn_al, multi_fr =Variable(img.cpu()), Variable(bg.cpu()), Variable(rcnn_al.cpu()), Variable(multi_fr.cpu())
input_im=torch.cat([img,bg,rcnn_al,multi_fr],dim=1)
alpha_pred,fg_pred_tmp=netM(img,bg,rcnn_al,multi_fr)
al_mask=(alpha_pred>0.95).type(torch.FloatTensor)
# for regions with alpha>0.95, simply use the image as fg
fg_pred=img*al_mask + fg_pred_tmp*(1-al_mask)
alpha_out=to_image(alpha_pred[0,...]);
#refine alpha with connected component
labels=label((alpha_out>0.05).astype(int))
try:
assert( labels.max() != 0 )
except:
continue
largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1
alpha_out=alpha_out*largestCC
alpha_out=(255*alpha_out[...,0]).astype(np.uint8)
fg_out=to_image(fg_pred[0,...]); fg_out=fg_out*np.expand_dims((alpha_out.astype(float)/255>0.01).astype(float),axis=2); fg_out=(255*fg_out).astype(np.uint8)
#Uncrop
R0=bgr_img0.shape[0];C0=bgr_img0.shape[1]
alpha_out0=uncrop(alpha_out,bbox,R0,C0)
fg_out0=uncrop(fg_out,bbox,R0,C0)
#compose
back_img10=cv2.resize(back_img10,(C0,R0)); back_img20=cv2.resize(back_img20,(C0,R0))
comp_im_tr1=composite4(fg_out0,back_img10,alpha_out0)
comp_im_tr2=composite4(fg_out0,back_img20,alpha_out0)
cv2.imwrite(result_path+'/'+filename.replace('_img','_out'), alpha_out0)
cv2.imwrite(result_path+'/'+filename.replace('_img','_fg'), cv2.cvtColor(fg_out0,cv2.COLOR_BGR2RGB))
cv2.imwrite(result_path+'/'+filename.replace('_img','_compose'), cv2.cvtColor(comp_im_tr1,cv2.COLOR_BGR2RGB))
cv2.imwrite(result_path+'/'+filename.replace('_img','_matte').format(i), cv2.cvtColor(comp_im_tr2,cv2.COLOR_BGR2RGB))
print('Done: ' + str(i+1) + '/' + str(len(test_imgs)))
`
Hi There!! I have tried this code on Intel PC, without Nividia CPU and CUDA. Is it possible to convert the code so that it works on CPU?
Right Now it is giving error at:
netM.cuda(); netM.eval() ==> Assertion Error: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx.
Can we convert this line to CPU so that it works on CPU without Nividia CUDA and GPU? like netM.cpu(); netM.eval()? or netM.to('cpu') something?