jiupinjia / stylized-neural-painting

Official Pytorch implementation of the preprint paper "Stylized Neural Painting", in CVPR 2021.
Creative Commons Zero v1.0 Universal
1.56k stars 262 forks source link

Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:2 #65

Closed nextyale-10 closed 3 months ago

nextyale-10 commented 4 months ago

Hi, I was trying to run demo_prog.py on a different GPU, so I added an argument with this line of code: parser.add_argument('--gpu',type=int,default=0,help="which gpu to use")

I also changed the device-selection code to device = torch.device(f"cuda:{args.gpu}").

However, it fails with the following error:

Traceback (most recent call last):
  File "demo_prog.py", line 117, in <module>
  File "demo_prog.py", line 92, in optimize_x
  File "painter.py", line 270, in _forward_pass
    + self.G_pred_canvas * (1 - G_pred_alpha)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cuda:0!

Even when I changed the original code to device = torch.device("cpu"), it still fails with a similar error:

Traceback (most recent call last):
  File "demo_prog.py", line 117, in <module>
  File "demo_prog.py", line 92, in optimize_x
  File "painter.py", line 270, in _forward_pass
    + self.G_pred_canvas * (1 - G_pred_alpha)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

It seems that some data is still being kept on cuda:0...

For more information, my entire code is pasted below:

import argparse
import torch
import torch.optim as optim

from painter import *

# settings
# Command-line options for the progressive painting demo. The parsed ``args``
# namespace is used by the rest of this script and is passed to
# ProgressivePainter.
parser = argparse.ArgumentParser(description='STYLIZED NEURAL PAINTING')
parser.add_argument('--img_path', type=str, default='./test_images/apple.jpg', metavar='str',
                    help='path to test image (default: ./test_images/apple.jpg)')
parser.add_argument('--renderer', type=str, default='oilpaintbrush', metavar='str',
                    help='renderer: [watercolor, markerpen, oilpaintbrush, rectangle] (default oilpaintbrush)')
parser.add_argument('--canvas_color', type=str, default='black', metavar='str',
                    help='canvas_color: [black, white] (default black)')
# Integer options use metavar 'N' so the usage text does not suggest a string.
parser.add_argument('--canvas_size', type=int, default=512, metavar='N',
                    help='size of the canvas for stroke rendering')
parser.add_argument('--keep_aspect_ratio', action='store_true', default=False,
                    help='keep input aspect ratio when saving outputs')
parser.add_argument('--max_m_strokes', type=int, default=500, metavar='N',
                    help='max number of strokes (default 500)')
parser.add_argument('--max_divide', type=int, default=5, metavar='N',
                    help='divide an image up-to max_divide x max_divide patches (default 5)')
parser.add_argument('--beta_L1', type=float, default=1.0,
                    help='weight for L1 loss (default: 1.0)')
parser.add_argument('--with_ot_loss', action='store_true', default=False,
                    help='improve the convergence by using optimal transportation loss')
parser.add_argument('--beta_ot', type=float, default=0.1,
                    help='weight for optimal transportation loss (default: 0.1)')
parser.add_argument('--net_G', type=str, default='zou-fusion-net-light', metavar='str',
                    help='net_G: plain-dcgan, plain-unet, huang-net, zou-fusion-net, '
                         'or zou-fusion-net-light (default: zou-fusion-net-light)')
parser.add_argument('--renderer_checkpoint_dir', type=str, default=r'./checkpoints_G_oilpaintbrush_light', metavar='str',
                    help='dir to load neu-renderer (default: ./checkpoints_G_oilpaintbrush_light)')
# The help text now reports the default that is actually configured (0.002).
parser.add_argument('--lr', type=float, default=0.002,
                    help='learning rate for stroke searching (default: 0.002)')
parser.add_argument('--output_dir', type=str, default=r'./output', metavar='str',
                    help='dir to save painting results (default: ./output)')
parser.add_argument('--disable_preview', action='store_true', default=False,
                    help='disable cv2.imshow, for running remotely without x-display')
parser.add_argument('--gpu', type=int, default=0, metavar='N',
                    help="which gpu to use")
args = parser.parse_args()

# Decide which device we want to run on

# Use the GPU selected with --gpu when CUDA is usable; otherwise fall back to
# the CPU so the script can still start on a machine without a GPU.
# NOTE(review): painter.py appears to choose its own device -- tensors created
# there stayed on cuda:0 when this value was changed -- so confirm that module
# honours the same setting before relying on --gpu.
device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")

def optimize_x(pt):
    """Progressively search stroke parameters with ``pt`` and render them.

    For every grid level from 1 to ``pt.max_divide`` the input image is split
    into patches, the stroke tensors ``pt.x_ctt``, ``pt.x_color`` and
    ``pt.x_alpha`` are optimised with RMSprop, and the resulting strokes are
    appended to the accumulated parameter array. The accumulated strokes are
    re-rendered after each level and once more at the end with JPG saving
    enabled.

    Args:
        pt: the painter object (built as ``ProgressivePainter`` by this
            script); its attributes are read and updated in place.

    Returns:
        None. Results are produced through ``pt._render``.
    """

    print('begin drawing...')

    # Accumulated stroke parameters; starts with zero strokes of width d.
    PARAMS = np.zeros([1, 0, pt.rderr.d], np.float32)

    # Start from a canvas matching the requested background colour. The two
    # assignments must be mutually exclusive: without the ``else`` a white
    # canvas was immediately overwritten with zeros and any other colour left
    # CANVAS_tmp undefined.
    canvas_shape = [1, 3, pt.net_G.out_size, pt.net_G.out_size]
    if pt.rderr.canvas_color == 'white':
        CANVAS_tmp = torch.ones(canvas_shape).to(device)
    else:
        CANVAS_tmp = torch.zeros(canvas_shape).to(device)

    for pt.m_grid in range(1, pt.max_divide + 1):

        pt.img_batch = utils.img2patches(pt.img_, pt.m_grid, pt.net_G.out_size).to(device)
        pt.G_final_pred_canvas = CANVAS_tmp

        # Only the stroke parameters are trained; the renderer stays frozen.
        pt.x_ctt.requires_grad = True
        pt.x_color.requires_grad = True
        pt.x_alpha.requires_grad = True
        utils.set_requires_grad(pt.net_G, False)

        pt.optimizer_x = optim.RMSprop([pt.x_ctt, pt.x_color, pt.x_alpha], lr=pt.lr, centered=True)

        pt.step_id = 0
        for pt.anchor_id in range(0, pt.m_strokes_per_block):
            iters_per_stroke = int(500 / pt.m_strokes_per_block)
            for i in range(iters_per_stroke):
                pt.G_pred_canvas = CANVAS_tmp

                # update x

                # Keep the parameters inside their valid ranges before use.
                pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
                pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
                pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)

                # NOTE(review): the forward/backward/optimizer-step calls seem
                # to be missing from this listing (pt._forward_pass is known to
                # be reached from this function) -- confirm against the real
                # script; they presumably belong between the two clamp groups.

                pt.x_ctt.data = torch.clamp(pt.x_ctt.data, 0.1, 1 - 0.1)
                pt.x_color.data = torch.clamp(pt.x_color.data, 0, 1)
                pt.x_alpha.data = torch.clamp(pt.x_alpha.data, 0, 1)

                pt.step_id += 1

        # Append this level's strokes and re-render everything found so far;
        # the rendered canvas is split into patches for the next grid level.
        v = pt._normalize_strokes(pt.x)
        v = pt._shuffle_strokes_and_reshape(v)
        PARAMS = np.concatenate([PARAMS, v], axis=1)
        CANVAS_tmp = pt._render(PARAMS, save_jpgs=False, save_video=False)
        CANVAS_tmp = utils.img2patches(CANVAS_tmp, pt.m_grid + 1, pt.net_G.out_size).to(device)

    # Final render with JPG saving enabled; the return value is not needed.
    pt._render(PARAMS, save_jpgs=True, save_video=False)

if __name__ == '__main__':

    # Build the painter from the parsed options, then run the stroke search;
    # constructing the painter alone never starts the drawing.
    pt = ProgressivePainter(args=args)
    optimize_x(pt)

Thank you in advance!