danielroich / PTI

Official Implementation for "Pivotal Tuning for Latent-based editing of Real Images" (ACM TOG 2022) https://arxiv.org/abs/2106.05744
MIT License

GPU error #45

Closed · luchaoqi closed this issue 1 year ago

luchaoqi commented 1 year ago

Hello, I was trying to integrate PTI into eg3d's `training/loss.py` (NVlabs/eg3d), but I ran into problems when calling the projectors from `PTI/training/projectors` (danielroich/PTI).

So here is how I call the w/w_plus projector (see the function `pti_projector` in the listing below):

eg3d code:

```python
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Loss functions."""

import numpy as np
import torch
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import upfirdn2d
from training.dual_discriminator import filtered_resizing

#----------------------------------------------------------------------------

class Loss:
    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass
        raise NotImplementedError()

#----------------------------------------------------------------------------

# ---------------------- project image into latent space --------------------- #
# modified code from https://github.com/oneThousand1000/EG3D-projector/tree/master/eg3d/projector
from training.projector import w_plus_projector, w_projector
from torchvision import transforms
import copy

def pti_projector(cur_G, cur_c, cur_image, device, latent_type='w_plus'):
    # # put image back to cpu for transforms
    # image = cur_image.cpu()
    # # normalize image
    # normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
    #                                  std=[0.5, 0.5, 0.5])
    # id_image = normalize(image)
    # id_image = torch.squeeze((id_image + 1) / 2, 0)
    id_image = cur_image.to(device)
    # c = c.to(device)
    c = torch.reshape(cur_c, (1, 25)).to(device) # 25 is the camera pose dimension 16 + 9
    G = cur_G
    if latent_type == 'w_plus':
        w = w_plus_projector.project(G, c, id_image, device=device, w_avg_samples=600)
    else:
        w = w_projector.project(G, c, id_image, device=device, w_avg_samples=600)
    print('w shape: ', w.shape)
    return w

class StyleGAN2Loss(Loss):
    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0,
                 pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0,
                 r1_gamma_init=0, r1_gamma_fade_kimg=0, neural_rendering_resolution_initial=64,
                 neural_rendering_resolution_final=None, neural_rendering_resolution_fade_kimg=0,
                 gpc_reg_fade_kimg=1000, gpc_reg_prob=None, dual_discrimination=False, filter_mode='antialiased'):
        super().__init__()
        self.device = device
        self.G = G
        self.D = D
        self.augment_pipe = augment_pipe
        self.r1_gamma = r1_gamma
        self.style_mixing_prob = style_mixing_prob
        self.pl_weight = pl_weight
        self.pl_batch_shrink = pl_batch_shrink
        self.pl_decay = pl_decay
        self.pl_no_weight_grad = pl_no_weight_grad
        self.pl_mean = torch.zeros([], device=device)
        self.blur_init_sigma = blur_init_sigma
        self.blur_fade_kimg = blur_fade_kimg
        self.r1_gamma_init = r1_gamma_init
        self.r1_gamma_fade_kimg = r1_gamma_fade_kimg
        self.neural_rendering_resolution_initial = neural_rendering_resolution_initial
        self.neural_rendering_resolution_final = neural_rendering_resolution_final
        self.neural_rendering_resolution_fade_kimg = neural_rendering_resolution_fade_kimg
        self.gpc_reg_fade_kimg = gpc_reg_fade_kimg
        self.gpc_reg_prob = gpc_reg_prob
        self.dual_discrimination = dual_discrimination
        self.filter_mode = filter_mode
        self.resample_filter = upfirdn2d.setup_filter([1,3,3,1], device=device)
        self.blur_raw_target = True
        assert self.gpc_reg_prob is None or (0 <= self.gpc_reg_prob <= 1)

    def run_G(self, z, c, swapping_prob, neural_rendering_resolution, update_emas=False):
        if swapping_prob is not None:
            c_swapped = torch.roll(c.clone(), 1, 0)
            c_gen_conditioning = torch.where(torch.rand((c.shape[0], 1), device=c.device) < swapping_prob, c_swapped, c)
        else:
            c_gen_conditioning = torch.zeros_like(c)

        ws = self.G.mapping(z, c_gen_conditioning, update_emas=update_emas)
        if self.style_mixing_prob > 0:
            with torch.autograd.profiler.record_function('style_mixing'):
                cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
        gen_output = self.G.synthesis(ws, c, neural_rendering_resolution=neural_rendering_resolution, update_emas=update_emas)
        return gen_output, ws

    def run_D(self, img, c, blur_sigma=0, blur_sigma_raw=0, update_emas=False):
        blur_size = np.floor(blur_sigma * 3)
        if blur_size > 0:
            with torch.autograd.profiler.record_function('blur'):
                f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div(blur_sigma).square().neg().exp2()
                img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum())

        if self.augment_pipe is not None:
            augmented_pair = self.augment_pipe(torch.cat([img['image'],
                torch.nn.functional.interpolate(img['image_raw'], size=img['image'].shape[2:], mode='bilinear', antialias=True)], dim=1))
            img['image'] = augmented_pair[:, :img['image'].shape[1]]
            img['image_raw'] = torch.nn.functional.interpolate(augmented_pair[:, img['image'].shape[1]:], size=img['image_raw'].shape[2:], mode='bilinear', antialias=True)

        logits = self.D(img, c, update_emas=update_emas)
        return logits

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        if self.G.rendering_kwargs.get('density_reg', 0) == 0:
            phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
        if self.r1_gamma == 0:
            phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
        blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0
        r1_gamma = self.r1_gamma

        alpha = min(cur_nimg / (self.gpc_reg_fade_kimg * 1e3), 1) if self.gpc_reg_fade_kimg > 0 else 1
        swapping_prob = (1 - alpha) * 1 + alpha * self.gpc_reg_prob if self.gpc_reg_prob is not None else None

        if self.neural_rendering_resolution_final is not None:
            alpha = min(cur_nimg / (self.neural_rendering_resolution_fade_kimg * 1e3), 1)
            neural_rendering_resolution = int(np.rint(self.neural_rendering_resolution_initial * (1 - alpha) + self.neural_rendering_resolution_final * alpha))
        else:
            neural_rendering_resolution = self.neural_rendering_resolution_initial

        real_img_raw = filtered_resizing(real_img, size=neural_rendering_resolution, f=self.resample_filter, filter_mode=self.filter_mode)

        if self.blur_raw_target:
            blur_size = np.floor(blur_sigma * 3)
            if blur_size > 0:
                f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div(blur_sigma).square().neg().exp2()
                real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum())

        real_img = {'image': real_img, 'image_raw': real_img_raw}

        # run PTI to get w/w_plus latent codes for real images
        # print(real_img.shape, real_c.shape, gen_z.shape, gen_c.shape)
        # torch.Size([8, 3, 512, 512]) torch.Size([8, 25]) torch.Size([8, 512]) torch.Size([8, 25])
        # convert gen_z to real_z
        batch_size = real_img['image'].shape[0]
        real_z = []
        for i in range(batch_size):
            cur_img = real_img['image'][i]
            cur_c = real_c[i]
            cur_z = pti_projector(self.G, cur_c, cur_img, device=self.device)
            real_z.append(cur_z)
        real_z = torch.stack(real_z)
        print('real_z', real_z.shape)

        # Gmain: Maximize logits for generated images.
        if phase in ['Gmain', 'Gboth']:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits)
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Density Regularization
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'l1':
            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
            initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * self.G.rendering_kwargs['density_reg_p_dist']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]
            TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
            TVloss.mul(gain).backward()

        # Alternative density regularization
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'monotonic-detach':
            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)

            initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front
            perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]
            monotonic_loss = torch.relu(sigma_initial.detach() - sigma_perturbed).mean() * 10
            monotonic_loss.mul(gain).backward()

            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
            initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]
            TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
            TVloss.mul(gain).backward()

        # Alternative density regularization
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'monotonic-fixed':
            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)

            initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front
            perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]
            monotonic_loss = torch.relu(sigma_initial - sigma_perturbed).mean() * 10
            monotonic_loss.mul(gain).backward()

            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)

            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
            initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1]//2]
            sigma_perturbed = sigma[:, sigma.shape[1]//2:]
            TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
            TVloss.mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if phase in ['Dmain', 'Dboth']:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution, update_emas=True)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits)
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if phase in ['Dmain', 'Dreg', 'Dboth']:
            name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                real_img_tmp_image = real_img['image'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
                real_img_tmp_image_raw = real_img['image_raw'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
                real_img_tmp = {'image': real_img_tmp_image, 'image_raw': real_img_tmp_image_raw}

                real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if phase in ['Dmain', 'Dboth']:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits)
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if phase in ['Dreg', 'Dboth']:
                    if self.dual_discrimination:
                        with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                            r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image'], real_img_tmp['image_raw']], create_graph=True, only_inputs=True)
                            r1_grads_image = r1_grads[0]
                            r1_grads_image_raw = r1_grads[1]
                        r1_penalty = r1_grads_image.square().sum([1,2,3]) + r1_grads_image_raw.square().sum([1,2,3])
                    else: # single discrimination
                        with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                            r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image']], create_graph=True, only_inputs=True)
                            r1_grads_image = r1_grads[0]
                        r1_penalty = r1_grads_image.square().sum([1,2,3])
                    loss_Dr1 = r1_penalty * (r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                (loss_Dreal + loss_Dr1).mean().mul(gain).backward()

#----------------------------------------------------------------------------
```

and here is the modified projector script:

w_plus_projector.py:

```python
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Project given image to the latent space of pretrained network pickle."""

import copy
import os

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

import dnnlib
import PIL
from camera_utils import LookAtPoseSampler

def project(
        G,
        c,
        # outdir,
        target: torch.Tensor,  # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
        *,
        num_steps=1000,
        w_avg_samples=10000,
        initial_learning_rate=0.01,
        initial_noise_factor=0.05,
        lr_rampdown_length=0.25,
        lr_rampup_length=0.05,
        noise_ramp_length=0.75,
        regularize_noise_weight=1e5,
        verbose=False,
        device: torch.device,
        initial_w=None,
        image_log_step=100,
        # w_name: str
):
    # os.makedirs(f'{outdir}/{w_name}_w_plus', exist_ok=True)
    # outdir = f'{outdir}/{w_name}_w_plus'
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float()  # type: ignore

    # Compute w stats.
    w_avg_path = './w_avg.npy'
    w_std_path = './w_std.npy'
    if (not os.path.exists(w_avg_path)) or (not os.path.exists(w_std_path)):
        print(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
        z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
        # c_samples = c.repeat(w_avg_samples, 1)

        # use avg look at point
        camera_lookat_point = torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device)
        cam2world_pose = LookAtPoseSampler.sample(3.14 / 2, 3.14 / 2, camera_lookat_point,
                                                  radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 4.2647  # FFHQ's FOV
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        c_samples = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        c_samples = c_samples.repeat(w_avg_samples, 1)

        w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples)  # [N, L, C]
        w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
        w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
        # print('save w_avg to ./w_avg.npy')
        # np.save('./w_avg.npy', w_avg)
        w_avg_tensor = torch.from_numpy(w_avg).cuda()
        w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
        # np.save(w_avg_path, w_avg)
        # np.save(w_std_path, w_std)
    else:
        # w_avg = np.load(w_avg_path)
        # w_std = np.load(w_std_path)
        raise Exception(' ')

    # z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    # c_samples = c.repeat(w_avg_samples, 1)
    # w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples)  # [N, L, C]
    # w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
    # w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
    # w_avg_tensor = torch.from_numpy(w_avg).cuda()
    # w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    start_w = initial_w if initial_w is not None else w_avg

    # Setup noise inputs.
    noise_bufs = {name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name}

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    # url = './networks/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f, map_location=device).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    start_w = np.repeat(start_w, G.backbone.mapping.num_ws, axis=1)
    w_opt = torch.tensor(start_w, dtype=torch.float32, device=device, requires_grad=True)  # pylint: disable=not-callable

    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=0.1)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in tqdm(range(num_steps)):
        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise)
        synth_images = G.synthesis(ws, c, noise_mode='const')['image']

        # if step % image_log_step == 0:
        #     with torch.no_grad():
        #         vis_img = (synth_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        #         PIL.Image.fromarray(vis_img[0].cpu().numpy(), 'RGB').save(f'{outdir}/{step}.png')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255 / 2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        # if step % 10 == 0:
        #     with torch.no_grad():
        #         print({f'step {step}, first projection _{w_name}': loss.detach().cpu()})

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    del G
    return w_opt
```

I got errors as shown below:

```
Computing W midpoint and stddev using 600 samples...
  0%|          | 0/1000 [00:00<?, ?it/s]/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/nn/modules/module.py:1488: UserWarning: operator() profile_node %106 : int = prim::profile_ivalue(%104)
 does not have profile information (Triggered internally at /opt/conda/conda-bld/pytorch_1674202356920/work/torch/csrc/jit/codegen/cuda/graph_fuser.cpp:105.)
  return forward_call(*args, **kwargs)
  0%|          | 0/1000 [00:07<?, ?it/s]
Traceback (most recent call last):
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 396, in <module>
    main() # pylint: disable=no-value-for-parameter
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 391, in main
    launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 101, in launch_training
    subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 52, in subprocess_fn
    training_loop.training_loop(rank=rank, **c)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/training_loop.py", line 286, in training_loop
    loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/loss.py", line 156, in accumulate_gradients
    cur_z = pti_projector(self.G, cur_c, cur_img, device=self.device)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/loss.py", line 49, in pti_projector
    w = w_plus_projector.project(G, c, id_image, device=device, w_avg_samples=600)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/projector/w_plus_projector.py", line 171, in project
    loss.backward()
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/function.py", line 275, in apply
    return user_fn(self, *args)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/torch_utils/ops/conv2d_gradfix.py", line 146, in backward
    grad_weight = Conv2dGradWeight.apply(grad_output, input, weight)
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/function.py", line 508, in apply
    return super().apply(*args, **kwargs)
  File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/torch_utils/ops/conv2d_gradfix.py", line 173, in forward
    return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1]
  File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/_ops.py", line 499, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument weight in method wrapper_CUDA__convolution_backward)
```

It seems the error happens at `loss.backward()`. I checked most of the variables, the loss, and the model to make sure they are on `cuda:0`, but I still had no luck solving this. Do you know how to make the loss backpropagate properly?
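
For anyone debugging the same crash: the message says the `weight` tensor reaching `convolution_backward` lives on the CPU while the gradients are on `cuda:0`, so some parameter or buffer used in the forward pass never made it to the GPU. A minimal, generic PyTorch sketch for locating it (the `G`, `vgg16`, and `noise_bufs` names refer to the variables in the projector above) could look like this:

```python
import torch

def find_offdevice_tensors(module, label, expected='cuda'):
    # Print every parameter/buffer whose device type differs from the expected one.
    for name, t in list(module.named_parameters()) + list(module.named_buffers()):
        if t.device.type != expected:
            print(f'{label}.{name} is on {t.device}')

# Called just before loss.backward() inside project():
# find_offdevice_tensors(G, 'G')          # the deep-copied generator
# find_offdevice_tensors(vgg16, 'vgg16')  # the LPIPS feature extractor
# for name, buf in noise_bufs.items():    # the optimized noise buffers
#     if buf.device.type != 'cuda':
#         print(f"noise_bufs['{name}'] is on {buf.device}")
```

If something inside `G` turns up on CPU, re-issuing `G.to(device)` after the `copy.deepcopy(...)` in `project()` (or moving `self.G` to the device before passing it into `pti_projector`) is a cheap thing to try; `copy.deepcopy` preserves devices, so the stray tensor is usually one that was created without an explicit `device=` argument.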

luchaoqi commented 1 year ago

Okay, I'll just run PTI ahead of time instead of calling it inside the training code here, to avoid this error.
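
For reference, a minimal sketch of that two-stage workflow (project offline once, then load the cached codes during training). The output layout and the `project_dataset` helper are illustrative, not part of the eg3d or PTI codebases:

```python
import os
import torch
from training.projector import w_plus_projector

def project_dataset(G, images, cams, device, out_dir='precomputed_latents'):
    # Run the (slow) per-image optimization once and cache the resulting w+ codes.
    os.makedirs(out_dir, exist_ok=True)
    for idx, (img, c) in enumerate(zip(images, cams)):
        c = torch.reshape(c, (1, 25)).to(device)  # 16 extrinsics + 9 intrinsics
        w = w_plus_projector.project(G, c, img.to(device), device=device, w_avg_samples=600)
        torch.save(w.detach().cpu(), os.path.join(out_dir, f'{idx}_w_plus.pt'))

# Later, inside loss.py, replace the per-batch pti_projector call with a cheap load:
# w = torch.load(f'precomputed_latents/{idx}_w_plus.pt', map_location=self.device)
```

Besides sidestepping the device mismatch, this avoids re-running a 1000-step optimization for every image of every batch inside `accumulate_gradients`.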