Here is how I call the w/w_plus projector (search for the function pti_projector):
eg3d code
```
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Loss functions."""
import numpy as np
import torch
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import upfirdn2d
from training.dual_discriminator import filtered_resizing
#----------------------------------------------------------------------------
class Loss:
    """Abstract base class for training losses.

    Subclasses implement `accumulate_gradients`, which computes the loss for
    one phase and calls `.backward()` on it (gradients are accumulated on the
    module parameters; the optimizer step happens in the training loop).
    """

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
        """Compute the loss for `phase` and backpropagate it. Must be overridden."""
        raise NotImplementedError()
#----------------------------------------------------------------------------
# ---------------------- project image into latent space --------------------- #
# modified code from https://github.com/oneThousand1000/EG3D-projector/tree/master/eg3d/projector
from training.projector import w_plus_projector, w_projector
from torchvision import transforms
import copy
def pti_projector(cur_G, cur_c, cur_image, device, latent_type='w_plus'):
    """Project one real image into the generator's latent space.

    Modified from https://github.com/oneThousand1000/EG3D-projector/tree/master/eg3d/projector

    Args:
        cur_G: EG3D generator.
        cur_c: camera parameters for this image; flattened to shape [1, 25]
            (16 extrinsics + 9 intrinsics).
        cur_image: target image tensor, moved to `device` before projection.
        device: torch device to run the optimization on.
        latent_type: 'w_plus' selects per-layer W+ projection, anything else
            selects single-W projection.

    Returns:
        The optimized latent code returned by the chosen projector.
    """
    target_image = cur_image.to(device)
    # 25 is the camera pose dimension: 16 + 9.
    camera = cur_c.reshape(1, 25).to(device)
    projector = w_plus_projector if latent_type == 'w_plus' else w_projector
    w = projector.project(cur_G, camera, target_image, device=device, w_avg_samples=600)
    print('w shape: ', w.shape)
    return w
class StyleGAN2Loss(Loss):
    """StyleGAN2/EG3D training loss.

    Implements the non-saturating GAN loss with:
      * dual discrimination (full-resolution image + raw neural rendering),
      * blur fade-in applied to both real and generated images,
      * generator pose conditioning (GPC) swapping,
      * density regularization variants ('l1', 'monotonic-detach',
        'monotonic-fixed'),
      * R1 gradient penalty on real images,
      * an experimental per-image PTI-style projection of real images into
        W/W+ (see `pti_projector`).
    """

    def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0,
                 pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0,
                 blur_fade_kimg=0, r1_gamma_init=0, r1_gamma_fade_kimg=0,
                 neural_rendering_resolution_initial=64, neural_rendering_resolution_final=None,
                 neural_rendering_resolution_fade_kimg=0, gpc_reg_fade_kimg=1000, gpc_reg_prob=None,
                 dual_discrimination=False, filter_mode='antialiased'):
        super().__init__()
        self.device = device
        self.G = G
        self.D = D
        self.augment_pipe = augment_pipe
        self.r1_gamma = r1_gamma
        self.style_mixing_prob = style_mixing_prob
        # Path-length regularization settings (unused unless pl_weight > 0).
        self.pl_weight = pl_weight
        self.pl_batch_shrink = pl_batch_shrink
        self.pl_decay = pl_decay
        self.pl_no_weight_grad = pl_no_weight_grad
        self.pl_mean = torch.zeros([], device=device)
        # Blur fade-in schedule for discriminator inputs.
        self.blur_init_sigma = blur_init_sigma
        self.blur_fade_kimg = blur_fade_kimg
        self.r1_gamma_init = r1_gamma_init
        self.r1_gamma_fade_kimg = r1_gamma_fade_kimg
        # Neural rendering resolution schedule (fade from initial to final).
        self.neural_rendering_resolution_initial = neural_rendering_resolution_initial
        self.neural_rendering_resolution_final = neural_rendering_resolution_final
        self.neural_rendering_resolution_fade_kimg = neural_rendering_resolution_fade_kimg
        # Generator pose conditioning (GPC) swapping schedule.
        self.gpc_reg_fade_kimg = gpc_reg_fade_kimg
        self.gpc_reg_prob = gpc_reg_prob
        self.dual_discrimination = dual_discrimination
        self.filter_mode = filter_mode
        self.resample_filter = upfirdn2d.setup_filter([1, 3, 3, 1], device=device)
        self.blur_raw_target = True
        assert self.gpc_reg_prob is None or (0 <= self.gpc_reg_prob <= 1)

    def run_G(self, z, c, swapping_prob, neural_rendering_resolution, update_emas=False):
        """Run the generator: map (z, c) -> ws, optionally style-mix, then synthesize.

        With probability `swapping_prob` each sample is conditioned on another
        sample's camera pose (GPC swapping) to decorrelate pose and identity.
        Returns (generator output dict, ws).
        """
        if swapping_prob is not None:
            c_swapped = torch.roll(c.clone(), 1, 0)
            c_gen_conditioning = torch.where(torch.rand((c.shape[0], 1), device=c.device) < swapping_prob, c_swapped, c)
        else:
            c_gen_conditioning = torch.zeros_like(c)
        ws = self.G.mapping(z, c_gen_conditioning, update_emas=update_emas)
        if self.style_mixing_prob > 0:
            with torch.autograd.profiler.record_function('style_mixing'):
                cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:]
        gen_output = self.G.synthesis(ws, c, neural_rendering_resolution=neural_rendering_resolution, update_emas=update_emas)
        return gen_output, ws

    def run_D(self, img, c, blur_sigma=0, blur_sigma_raw=0, update_emas=False):
        """Run the discriminator on an image dict, with optional blur and augmentation.

        `img` is a dict with 'image' (full resolution) and 'image_raw' (neural
        rendering resolution). Augmentation is applied jointly to both by
        concatenating them along channels so they receive identical transforms.
        """
        blur_size = np.floor(blur_sigma * 3)
        if blur_size > 0:
            with torch.autograd.profiler.record_function('blur'):
                f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div(blur_sigma).square().neg().exp2()
                img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum())
        if self.augment_pipe is not None:
            augmented_pair = self.augment_pipe(torch.cat([img['image'],
                torch.nn.functional.interpolate(img['image_raw'], size=img['image'].shape[2:], mode='bilinear', antialias=True)],
                dim=1))
            img['image'] = augmented_pair[:, :img['image'].shape[1]]
            img['image_raw'] = torch.nn.functional.interpolate(augmented_pair[:, img['image'].shape[1]:], size=img['image_raw'].shape[2:], mode='bilinear', antialias=True)
        logits = self.D(img, c, update_emas=update_emas)
        return logits

    def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg):
        """Compute and backpropagate the loss for one training phase."""
        assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
        if self.G.rendering_kwargs.get('density_reg', 0) == 0:
            phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase)
        if self.r1_gamma == 0:
            phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase)
        blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0
        r1_gamma = self.r1_gamma

        # GPC swapping probability fades in over gpc_reg_fade_kimg.
        alpha = min(cur_nimg / (self.gpc_reg_fade_kimg * 1e3), 1) if self.gpc_reg_fade_kimg > 0 else 1
        swapping_prob = (1 - alpha) * 1 + alpha * self.gpc_reg_prob if self.gpc_reg_prob is not None else None

        if self.neural_rendering_resolution_final is not None:
            alpha = min(cur_nimg / (self.neural_rendering_resolution_fade_kimg * 1e3), 1)
            neural_rendering_resolution = int(np.rint(self.neural_rendering_resolution_initial * (1 - alpha) + self.neural_rendering_resolution_final * alpha))
        else:
            neural_rendering_resolution = self.neural_rendering_resolution_initial

        real_img_raw = filtered_resizing(real_img, size=neural_rendering_resolution, f=self.resample_filter, filter_mode=self.filter_mode)

        if self.blur_raw_target:
            blur_size = np.floor(blur_sigma * 3)
            if blur_size > 0:
                f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div(blur_sigma).square().neg().exp2()
                real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum())

        real_img = {'image': real_img, 'image_raw': real_img_raw}

        # Run PTI to get w/w_plus latent codes for the real images.
        # Shapes (batch 8): image [8,3,512,512], c [8,25], z [8,512].
        # NOTE(review): this runs a full per-image optimization (hundreds of
        # generator forward/backward passes) on EVERY accumulate_gradients
        # call, which dominates training time; consider caching per-image
        # latents or projecting outside the training loop.
        batch_size = real_img['image'].shape[0]
        real_z = []
        for i in range(batch_size):
            cur_img = real_img['image'][i]
            cur_c = real_c[i]
            cur_z = pti_projector(self.G, cur_c, cur_img, device=self.device)
            real_z.append(cur_z)
        real_z = torch.stack(real_z)
        print('real_z', real_z.shape)

        # Gmain: Maximize logits for generated images.
        if phase in ['Gmain', 'Gboth']:
            with torch.autograd.profiler.record_function('Gmain_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Gmain = torch.nn.functional.softplus(-gen_logits)
                training_stats.report('Loss/G/loss', loss_Gmain)
            with torch.autograd.profiler.record_function('Gmain_backward'):
                loss_Gmain.mean().mul(gain).backward()

        # Density Regularization ('l1'): L1 between densities at nearby points.
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'l1':
            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)
            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    # BUGFIX: original referenced undefined names `z` and `c`
                    # (copied from run_G); in this scope they are gen_z / gen_c.
                    ws[:, cutoff:] = self.G.mapping(torch.randn_like(gen_z), gen_c, update_emas=False)[:, cutoff:]
            initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * self.G.rendering_kwargs['density_reg_p_dist']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1] // 2]
            sigma_perturbed = sigma[:, sigma.shape[1] // 2:]
            TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
            TVloss.mul(gain).backward()

        # Alternative density regularization ('monotonic-detach'): encourage
        # density to decrease moving backward along -z; front values detached.
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'monotonic-detach':
            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)
            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1  # Front
            perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1 / 256) * self.G.rendering_kwargs['box_warp']  # Behind
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1] // 2]
            sigma_perturbed = sigma[:, sigma.shape[1] // 2:]
            monotonic_loss = torch.relu(sigma_initial.detach() - sigma_perturbed).mean() * 10
            monotonic_loss.mul(gain).backward()

            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)
            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    # BUGFIX: gen_z / gen_c, not undefined z / c.
                    ws[:, cutoff:] = self.G.mapping(torch.randn_like(gen_z), gen_c, update_emas=False)[:, cutoff:]
            initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1 / 256) * self.G.rendering_kwargs['box_warp']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1] // 2]
            sigma_perturbed = sigma[:, sigma.shape[1] // 2:]
            TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
            TVloss.mul(gain).backward()

        # Alternative density regularization ('monotonic-fixed'): same as above
        # but without detaching the front densities.
        if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'monotonic-fixed':
            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)
            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1  # Front
            perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1 / 256) * self.G.rendering_kwargs['box_warp']  # Behind
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1] // 2]
            sigma_perturbed = sigma[:, sigma.shape[1] // 2:]
            monotonic_loss = torch.relu(sigma_initial - sigma_perturbed).mean() * 10
            monotonic_loss.mul(gain).backward()

            if swapping_prob is not None:
                c_swapped = torch.roll(gen_c.clone(), 1, 0)
                c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c)
            else:
                c_gen_conditioning = torch.zeros_like(gen_c)
            ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False)
            if self.style_mixing_prob > 0:
                with torch.autograd.profiler.record_function('style_mixing'):
                    cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
                    cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
                    # BUGFIX: gen_z / gen_c, not undefined z / c.
                    ws[:, cutoff:] = self.G.mapping(torch.randn_like(gen_z), gen_c, update_emas=False)[:, cutoff:]
            initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1
            perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1 / 256) * self.G.rendering_kwargs['box_warp']
            all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
            sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma']
            sigma_initial = sigma[:, :sigma.shape[1] // 2]
            sigma_perturbed = sigma[:, sigma.shape[1] // 2:]
            TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg']
            TVloss.mul(gain).backward()

        # Dmain: Minimize logits for generated images.
        loss_Dgen = 0
        if phase in ['Dmain', 'Dboth']:
            with torch.autograd.profiler.record_function('Dgen_forward'):
                gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution, update_emas=True)
                gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True)
                training_stats.report('Loss/scores/fake', gen_logits)
                training_stats.report('Loss/signs/fake', gen_logits.sign())
                loss_Dgen = torch.nn.functional.softplus(gen_logits)
            with torch.autograd.profiler.record_function('Dgen_backward'):
                loss_Dgen.mean().mul(gain).backward()

        # Dmain: Maximize logits for real images.
        # Dr1: Apply R1 regularization.
        if phase in ['Dmain', 'Dreg', 'Dboth']:
            name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1'
            with torch.autograd.profiler.record_function(name + '_forward'):
                # Gradients w.r.t. the real images are needed only for R1.
                real_img_tmp_image = real_img['image'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
                real_img_tmp_image_raw = real_img['image_raw'].detach().requires_grad_(phase in ['Dreg', 'Dboth'])
                real_img_tmp = {'image': real_img_tmp_image, 'image_raw': real_img_tmp_image_raw}

                real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma)
                training_stats.report('Loss/scores/real', real_logits)
                training_stats.report('Loss/signs/real', real_logits.sign())

                loss_Dreal = 0
                if phase in ['Dmain', 'Dboth']:
                    loss_Dreal = torch.nn.functional.softplus(-real_logits)
                    training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)

                loss_Dr1 = 0
                if phase in ['Dreg', 'Dboth']:
                    if self.dual_discrimination:
                        with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                            r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image'], real_img_tmp['image_raw']], create_graph=True, only_inputs=True)
                            r1_grads_image = r1_grads[0]
                            r1_grads_image_raw = r1_grads[1]
                        r1_penalty = r1_grads_image.square().sum([1, 2, 3]) + r1_grads_image_raw.square().sum([1, 2, 3])
                    else:  # single discrimination
                        with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
                            r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image']], create_graph=True, only_inputs=True)
                            r1_grads_image = r1_grads[0]
                        r1_penalty = r1_grads_image.square().sum([1, 2, 3])
                    loss_Dr1 = r1_penalty * (r1_gamma / 2)
                    training_stats.report('Loss/r1_penalty', r1_penalty)
                    training_stats.report('Loss/D/reg', loss_Dr1)

            with torch.autograd.profiler.record_function(name + '_backward'):
                (loss_Dreal + loss_Dr1).mean().mul(gain).backward()
#----------------------------------------------------------------------------
```
and here is the modified projector script:
w_plus_projector.py
```
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Project given image to the latent space of pretrained network pickle."""
import copy
import os
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import dnnlib
import PIL
from camera_utils import LookAtPoseSampler
def project(
        G,
        c,
        # outdir,
        target: torch.Tensor,  # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
        *,
        num_steps=1000,
        w_avg_samples=10000,
        initial_learning_rate=0.01,
        initial_noise_factor=0.05,
        lr_rampdown_length=0.25,
        lr_rampup_length=0.05,
        noise_ramp_length=0.75,
        regularize_noise_weight=1e5,
        verbose=False,
        device: torch.device,
        initial_w=None,
        image_log_step=100,
        # w_name: str
):
    """Project `target` into the W+ space of generator `G` by LPIPS optimization.

    Optimizes a W+ latent (one w per synthesis layer) plus the synthesis noise
    buffers so that G(w) matches `target` in VGG16/LPIPS feature space, with a
    noise-decorrelation regularizer.

    Args:
        G: EG3D generator (deep-copied internally; the caller's G is untouched).
        c: camera parameters [1, 25] used for synthesis of the target view.
        target: image tensor [C, H, W]; H and W must equal G.img_resolution.
        num_steps: optimization iterations.
        w_avg_samples: samples used to estimate the W mean/stddev.
        initial_learning_rate: peak learning rate of the ramped schedule.
        initial_noise_factor: scale of the decaying noise added to w each step.
        verbose: print per-step stats via `logprint`.
        device: device on which ALL tensors are placed. Keeping everything on
            this single device is required for backward() to succeed.
        initial_w: optional starting latent [1, 1, C] (numpy array or torch
            tensor); defaults to the estimated W mean.
        image_log_step: kept for interface compatibility (logging is disabled).

    Returns:
        The optimized W+ latent `w_opt` of shape [1, num_ws, C] (requires_grad).
    """
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    # Work on a frozen float32 copy so the caller's G is not mutated.
    G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float()  # type: ignore

    # Compute w stats.
    w_avg_path = './w_avg.npy'
    w_std_path = './w_std.npy'
    if (not os.path.exists(w_avg_path)) or (not os.path.exists(w_std_path)):
        print(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
        z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
        # Condition the mapping network on an average look-at camera rather
        # than the target's own pose, so the W statistics are pose-neutral.
        camera_lookat_point = torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device)
        cam2world_pose = LookAtPoseSampler.sample(3.14 / 2, 3.14 / 2, camera_lookat_point,
                                                  radius=G.rendering_kwargs['avg_camera_radius'], device=device)
        focal_length = 4.2647  # FFHQ's FOV
        intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
        c_samples = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
        c_samples = c_samples.repeat(w_avg_samples, 1)

        w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples)  # [N, L, C]
        w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)  # [N, 1, C]
        w_avg = np.mean(w_samples, axis=0, keepdims=True)  # [1, 1, C]
        w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
        # BUGFIX: removed the unused `w_avg_tensor = torch.from_numpy(w_avg).cuda()`;
        # it hard-coded CUDA instead of `device` and served no purpose.
    else:
        # Cached-stats loading is intentionally disabled; fail loudly with a
        # message instead of the previous opaque `Exception(' ')`.
        raise Exception(f'Loading cached W stats from {w_avg_path} / {w_std_path} is disabled; delete the files to recompute.')

    if initial_w is not None:
        # Accept either a numpy array or a torch tensor (backward compatible).
        start_w = initial_w.detach().cpu().numpy() if torch.is_tensor(initial_w) else initial_w
    else:
        start_w = w_avg

    # Setup noise inputs.
    noise_bufs = {name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name}

    # Load VGG16 feature detector.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f, map_location=device).eval().to(device)

    # Features for target image. VGG/LPIPS operates at <=256x256.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    # Broadcast the single w to every synthesis layer (W+ parameterization).
    start_w = np.repeat(start_w, G.backbone.mapping.num_ws, axis=1)
    w_opt = torch.tensor(start_w, dtype=torch.float32, device=device,
                         requires_grad=True)  # pylint: disable=not-callable
    # The lr is overwritten by the ramp schedule before every step, so the
    # constructor value only documents intent (was a hard-coded 0.1).
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999),
                                 lr=initial_learning_rate)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in tqdm(range(num_steps)):
        # Learning rate schedule: cosine ramp-down with a short linear ramp-up.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w (with decaying exploration noise).
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise)
        synth_images = G.synthesis(ws, c, noise_mode='const')['image']

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255 / 2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        # Noise regularization: penalize spatial autocorrelation at all scales.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Normalize noise to zero mean, unit variance.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    del G
    return w_opt
```
I got errors as shown below:
Computing W midpoint and stddev using 600 samples...
0%| | 0/1000 [00:00<?, ?it/s]/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/nn/modules/module.py:1488: UserWarning: operator() profile_node %106 : int = prim::profile_ivalue(%104)
does not have profile information (Triggered internally at /opt/conda/conda-bld/pytorch_1674202356920/work/torch/csrc/jit/codegen/cuda/graph_fuser.cpp:105.)
return forward_call(*args, **kwargs)
0%| | 0/1000 [00:07<?, ?it/s]
Traceback (most recent call last):
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 396, in <module>
main() # pylint: disable=no-value-for-parameter
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1128, in __call__
return self.main(*args, **kwargs)
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1053, in main
rv = self.invoke(ctx)
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 1395, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/click/core.py", line 754, in invoke
return __callback(*args, **kwargs)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 391, in main
launch_training(c=c, desc=desc, outdir=opts.outdir, dry_run=opts.dry_run)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 101, in launch_training
subprocess_fn(rank=0, c=c, temp_dir=temp_dir)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/train.py", line 52, in subprocess_fn
training_loop.training_loop(rank=rank, **c)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/training_loop.py", line 286, in training_loop
loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, gain=phase.interval, cur_nimg=cur_nimg)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/loss.py", line 156, in accumulate_gradients
cur_z = pti_projector(self.G, cur_c, cur_img, device=self.device)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/loss.py", line 49, in pti_projector
w = w_plus_projector.project(G, c, id_image, device=device, w_avg_samples=600)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/training/projector/w_plus_projector.py", line 171, in project
loss.backward()
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
torch.autograd.backward(
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/function.py", line 275, in apply
return user_fn(self, *args)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/torch_utils/ops/conv2d_gradfix.py", line 146, in backward
grad_weight = Conv2dGradWeight.apply(grad_output, input, weight)
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/autograd/function.py", line 508, in apply
return super().apply(*args, **kwargs)
File "/playpen-nas-ssd/luchao/projects/eg3d/eg3d/torch_utils/ops/conv2d_gradfix.py", line 173, in forward
return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1]
File "/playpen-nas-ssd/luchao/software/miniconda3/envs/eg3d/lib/python3.9/site-packages/torch/_ops.py", line 499, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument weight in method wrapper_CUDA__convolution_backward)
It seems the error happens at loss.backward(). I checked most of the variables, the loss, and the model to make sure they are on cuda:0, but I still had no luck solving this. Do you know how to make the loss backpropagate properly?
Hello, I was trying to implement PTI into eg3d/loss.py at main · NVlabs/eg3d but I got some problems when calling the PTI/training/projectors at main · danielroich/PTI
So here is how I call w/w_plus projector (search for function
pti_projector
):eg3d code
``` # SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. """Loss functions.""" import numpy as np import torch from torch_utils import training_stats from torch_utils.ops import conv2d_gradfix from torch_utils.ops import upfirdn2d from training.dual_discriminator import filtered_resizing #---------------------------------------------------------------------------- class Loss: def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): # to be overridden by subclass raise NotImplementedError() #---------------------------------------------------------------------------- # ---------------------- project image into latent space --------------------- # # modified code from https://github.com/oneThousand1000/EG3D-projector/tree/master/eg3d/projector from training.projector import w_plus_projector, w_projector from torchvision import transforms import copy def pti_projector(cur_G, cur_c, cur_image, device, latent_type='w_plus'): # # put image back to cpu for transforms # image = cur_image.cpu() # # normalize image # normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], # std=[0.5, 0.5, 0.5]) # id_image = normalize(image) # id_image = torch.squeeze((id_image + 1) / 2, 0) id_image = cur_image.to(device) # c = c.to(device) c = torch.reshape(cur_c, (1, 25)).to(device) # 25 is the camera pose dimension 16 + 9 G = cur_G if latent_type == 'w_plus': w = w_plus_projector.project(G, c, id_image, device=device, w_avg_samples=600) else: w = w_projector.project(G, 
c, id_image, device=device, w_avg_samples=600) print('w shape: ', w.shape) return w class StyleGAN2Loss(Loss): def __init__(self, device, G, D, augment_pipe=None, r1_gamma=10, style_mixing_prob=0, pl_weight=0, pl_batch_shrink=2, pl_decay=0.01, pl_no_weight_grad=False, blur_init_sigma=0, blur_fade_kimg=0, r1_gamma_init=0, r1_gamma_fade_kimg=0, neural_rendering_resolution_initial=64, neural_rendering_resolution_final=None, neural_rendering_resolution_fade_kimg=0, gpc_reg_fade_kimg=1000, gpc_reg_prob=None, dual_discrimination=False, filter_mode='antialiased'): super().__init__() self.device = device self.G = G self.D = D self.augment_pipe = augment_pipe self.r1_gamma = r1_gamma self.style_mixing_prob = style_mixing_prob self.pl_weight = pl_weight self.pl_batch_shrink = pl_batch_shrink self.pl_decay = pl_decay self.pl_no_weight_grad = pl_no_weight_grad self.pl_mean = torch.zeros([], device=device) self.blur_init_sigma = blur_init_sigma self.blur_fade_kimg = blur_fade_kimg self.r1_gamma_init = r1_gamma_init self.r1_gamma_fade_kimg = r1_gamma_fade_kimg self.neural_rendering_resolution_initial = neural_rendering_resolution_initial self.neural_rendering_resolution_final = neural_rendering_resolution_final self.neural_rendering_resolution_fade_kimg = neural_rendering_resolution_fade_kimg self.gpc_reg_fade_kimg = gpc_reg_fade_kimg self.gpc_reg_prob = gpc_reg_prob self.dual_discrimination = dual_discrimination self.filter_mode = filter_mode self.resample_filter = upfirdn2d.setup_filter([1,3,3,1], device=device) self.blur_raw_target = True assert self.gpc_reg_prob is None or (0 <= self.gpc_reg_prob <= 1) def run_G(self, z, c, swapping_prob, neural_rendering_resolution, update_emas=False): if swapping_prob is not None: c_swapped = torch.roll(c.clone(), 1, 0) c_gen_conditioning = torch.where(torch.rand((c.shape[0], 1), device=c.device) < swapping_prob, c_swapped, c) else: c_gen_conditioning = torch.zeros_like(c) ws = self.G.mapping(z, c_gen_conditioning, update_emas=update_emas) 
if self.style_mixing_prob > 0: with torch.autograd.profiler.record_function('style_mixing'): cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:] gen_output = self.G.synthesis(ws, c, neural_rendering_resolution=neural_rendering_resolution, update_emas=update_emas) return gen_output, ws def run_D(self, img, c, blur_sigma=0, blur_sigma_raw=0, update_emas=False): blur_size = np.floor(blur_sigma * 3) if blur_size > 0: with torch.autograd.profiler.record_function('blur'): f = torch.arange(-blur_size, blur_size + 1, device=img['image'].device).div(blur_sigma).square().neg().exp2() img['image'] = upfirdn2d.filter2d(img['image'], f / f.sum()) if self.augment_pipe is not None: augmented_pair = self.augment_pipe(torch.cat([img['image'], torch.nn.functional.interpolate(img['image_raw'], size=img['image'].shape[2:], mode='bilinear', antialias=True)], dim=1)) img['image'] = augmented_pair[:, :img['image'].shape[1]] img['image_raw'] = torch.nn.functional.interpolate(augmented_pair[:, img['image'].shape[1]:], size=img['image_raw'].shape[2:], mode='bilinear', antialias=True) logits = self.D(img, c, update_emas=update_emas) return logits def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, gain, cur_nimg): assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth'] if self.G.rendering_kwargs.get('density_reg', 0) == 0: phase = {'Greg': 'none', 'Gboth': 'Gmain'}.get(phase, phase) if self.r1_gamma == 0: phase = {'Dreg': 'none', 'Dboth': 'Dmain'}.get(phase, phase) blur_sigma = max(1 - cur_nimg / (self.blur_fade_kimg * 1e3), 0) * self.blur_init_sigma if self.blur_fade_kimg > 0 else 0 r1_gamma = self.r1_gamma alpha = min(cur_nimg / (self.gpc_reg_fade_kimg * 1e3), 1) if self.gpc_reg_fade_kimg > 0 else 1 swapping_prob = 
(1 - alpha) * 1 + alpha * self.gpc_reg_prob if self.gpc_reg_prob is not None else None if self.neural_rendering_resolution_final is not None: alpha = min(cur_nimg / (self.neural_rendering_resolution_fade_kimg * 1e3), 1) neural_rendering_resolution = int(np.rint(self.neural_rendering_resolution_initial * (1 - alpha) + self.neural_rendering_resolution_final * alpha)) else: neural_rendering_resolution = self.neural_rendering_resolution_initial real_img_raw = filtered_resizing(real_img, size=neural_rendering_resolution, f=self.resample_filter, filter_mode=self.filter_mode) if self.blur_raw_target: blur_size = np.floor(blur_sigma * 3) if blur_size > 0: f = torch.arange(-blur_size, blur_size + 1, device=real_img_raw.device).div(blur_sigma).square().neg().exp2() real_img_raw = upfirdn2d.filter2d(real_img_raw, f / f.sum()) real_img = {'image': real_img, 'image_raw': real_img_raw} # run PTI to get w/w_plus latent codes for real images # print(real_img.shape, real_c.shape, gen_z.shape, gen_c.shape) # torch.Size([8, 3, 512, 512]) torch.Size([8, 25]) torch.Size([8, 512]) torch.Size([8, 25]) # convert gen_z to real_z batch_size = real_img['image'].shape[0] real_z = [] for i in range(batch_size): cur_img = real_img['image'][i] cur_c = real_c[i] cur_z = pti_projector(self.G, cur_c, cur_img, device=self.device) real_z.append(cur_z) real_z = torch.stack(real_z) print('real_z', real_z.shape) # Gmain: Maximize logits for generated images. 
if phase in ['Gmain', 'Gboth']: with torch.autograd.profiler.record_function('Gmain_forward'): gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution) gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma) training_stats.report('Loss/scores/fake', gen_logits) training_stats.report('Loss/signs/fake', gen_logits.sign()) loss_Gmain = torch.nn.functional.softplus(-gen_logits) training_stats.report('Loss/G/loss', loss_Gmain) with torch.autograd.profiler.record_function('Gmain_backward'): loss_Gmain.mean().mul(gain).backward() # Density Regularization if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'l1': if swapping_prob is not None: c_swapped = torch.roll(gen_c.clone(), 1, 0) c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) else: c_gen_conditioning = torch.zeros_like(gen_c) ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False) if self.style_mixing_prob > 0: with torch.autograd.profiler.record_function('style_mixing'): cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:] initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1 perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * self.G.rendering_kwargs['density_reg_p_dist'] all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma'] sigma_initial = sigma[:, :sigma.shape[1]//2] sigma_perturbed = sigma[:, sigma.shape[1]//2:] TVloss = 
torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg'] TVloss.mul(gain).backward() # Alternative density regularization if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'monotonic-detach': if swapping_prob is not None: c_swapped = torch.roll(gen_c.clone(), 1, 0) c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) else: c_gen_conditioning = torch.zeros_like(gen_c) ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False) initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma'] sigma_initial = sigma[:, :sigma.shape[1]//2] sigma_perturbed = sigma[:, sigma.shape[1]//2:] monotonic_loss = torch.relu(sigma_initial.detach() - sigma_perturbed).mean() * 10 monotonic_loss.mul(gain).backward() if swapping_prob is not None: c_swapped = torch.roll(gen_c.clone(), 1, 0) c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) else: c_gen_conditioning = torch.zeros_like(gen_c) ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False) if self.style_mixing_prob > 0: with torch.autograd.profiler.record_function('style_mixing'): cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:] initial_coordinates = torch.rand((ws.shape[0], 1000, 3), 
device=ws.device) * 2 - 1 perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp'] all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma'] sigma_initial = sigma[:, :sigma.shape[1]//2] sigma_perturbed = sigma[:, sigma.shape[1]//2:] TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg'] TVloss.mul(gain).backward() # Alternative density regularization if phase in ['Greg', 'Gboth'] and self.G.rendering_kwargs.get('density_reg', 0) > 0 and self.G.rendering_kwargs['reg_type'] == 'monotonic-fixed': if swapping_prob is not None: c_swapped = torch.roll(gen_c.clone(), 1, 0) c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) else: c_gen_conditioning = torch.zeros_like(gen_c) ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False) initial_coordinates = torch.rand((ws.shape[0], 2000, 3), device=ws.device) * 2 - 1 # Front perturbed_coordinates = initial_coordinates + torch.tensor([0, 0, -1], device=ws.device) * (1/256) * self.G.rendering_kwargs['box_warp'] # Behind all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma'] sigma_initial = sigma[:, :sigma.shape[1]//2] sigma_perturbed = sigma[:, sigma.shape[1]//2:] monotonic_loss = torch.relu(sigma_initial - sigma_perturbed).mean() * 10 monotonic_loss.mul(gain).backward() if swapping_prob is not None: c_swapped = torch.roll(gen_c.clone(), 1, 0) c_gen_conditioning = torch.where(torch.rand([], device=gen_c.device) < swapping_prob, c_swapped, gen_c) else: c_gen_conditioning = torch.zeros_like(gen_c) ws = self.G.mapping(gen_z, c_gen_conditioning, update_emas=False) if 
self.style_mixing_prob > 0: with torch.autograd.profiler.record_function('style_mixing'): cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1]) cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1])) ws[:, cutoff:] = self.G.mapping(torch.randn_like(z), c, update_emas=False)[:, cutoff:] initial_coordinates = torch.rand((ws.shape[0], 1000, 3), device=ws.device) * 2 - 1 perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * (1/256) * self.G.rendering_kwargs['box_warp'] all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1) sigma = self.G.sample_mixed(all_coordinates, torch.randn_like(all_coordinates), ws, update_emas=False)['sigma'] sigma_initial = sigma[:, :sigma.shape[1]//2] sigma_perturbed = sigma[:, sigma.shape[1]//2:] TVloss = torch.nn.functional.l1_loss(sigma_initial, sigma_perturbed) * self.G.rendering_kwargs['density_reg'] TVloss.mul(gain).backward() # Dmain: Minimize logits for generated images. loss_Dgen = 0 if phase in ['Dmain', 'Dboth']: with torch.autograd.profiler.record_function('Dgen_forward'): gen_img, _gen_ws = self.run_G(gen_z, gen_c, swapping_prob=swapping_prob, neural_rendering_resolution=neural_rendering_resolution, update_emas=True) gen_logits = self.run_D(gen_img, gen_c, blur_sigma=blur_sigma, update_emas=True) training_stats.report('Loss/scores/fake', gen_logits) training_stats.report('Loss/signs/fake', gen_logits.sign()) loss_Dgen = torch.nn.functional.softplus(gen_logits) with torch.autograd.profiler.record_function('Dgen_backward'): loss_Dgen.mean().mul(gain).backward() # Dmain: Maximize logits for real images. # Dr1: Apply R1 regularization. 
if phase in ['Dmain', 'Dreg', 'Dboth']: name = 'Dreal' if phase == 'Dmain' else 'Dr1' if phase == 'Dreg' else 'Dreal_Dr1' with torch.autograd.profiler.record_function(name + '_forward'): real_img_tmp_image = real_img['image'].detach().requires_grad_(phase in ['Dreg', 'Dboth']) real_img_tmp_image_raw = real_img['image_raw'].detach().requires_grad_(phase in ['Dreg', 'Dboth']) real_img_tmp = {'image': real_img_tmp_image, 'image_raw': real_img_tmp_image_raw} real_logits = self.run_D(real_img_tmp, real_c, blur_sigma=blur_sigma) training_stats.report('Loss/scores/real', real_logits) training_stats.report('Loss/signs/real', real_logits.sign()) loss_Dreal = 0 if phase in ['Dmain', 'Dboth']: loss_Dreal = torch.nn.functional.softplus(-real_logits) training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal) loss_Dr1 = 0 if phase in ['Dreg', 'Dboth']: if self.dual_discrimination: with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image'], real_img_tmp['image_raw']], create_graph=True, only_inputs=True) r1_grads_image = r1_grads[0] r1_grads_image_raw = r1_grads[1] r1_penalty = r1_grads_image.square().sum([1,2,3]) + r1_grads_image_raw.square().sum([1,2,3]) else: # single discrimination with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients(): r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp['image']], create_graph=True, only_inputs=True) r1_grads_image = r1_grads[0] r1_penalty = r1_grads_image.square().sum([1,2,3]) loss_Dr1 = r1_penalty * (r1_gamma / 2) training_stats.report('Loss/r1_penalty', r1_penalty) training_stats.report('Loss/D/reg', loss_Dr1) with torch.autograd.profiler.record_function(name + '_backward'): (loss_Dreal + loss_Dr1).mean().mul(gain).backward() #---------------------------------------------------------------------------- ```and here is the modified projector 
scripts:
w_plus_projector.py
``` # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Project given image to the latent space of pretrained network pickle.""" import copy import os import numpy as np import torch import torch.nn.functional as F from tqdm import tqdm import dnnlib import PIL from camera_utils import LookAtPoseSampler def project( G, c, # outdir, target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution *, num_steps=1000, w_avg_samples=10000, initial_learning_rate=0.01, initial_noise_factor=0.05, lr_rampdown_length=0.25, lr_rampup_length=0.05, noise_ramp_length=0.75, regularize_noise_weight=1e5, verbose=False, device: torch.device, initial_w=None, image_log_step=100, # w_name: str ): # os.makedirs(f'{outdir}/{w_name}_w_plus', exist_ok=True) # outdir = f'{outdir}/{w_name}_w_plus' assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution) def logprint(*args): if verbose: print(*args) G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float() # type: ignore # Compute w stats. 
w_avg_path = './w_avg.npy' w_std_path = './w_std.npy' if (not os.path.exists(w_avg_path)) or (not os.path.exists(w_std_path)): print(f'Computing W midpoint and stddev using {w_avg_samples} samples...') z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim) # c_samples = c.repeat(w_avg_samples, 1) # use avg look at point camera_lookat_point = torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device) cam2world_pose = LookAtPoseSampler.sample(3.14 / 2, 3.14 / 2, camera_lookat_point, radius=G.rendering_kwargs['avg_camera_radius'], device=device) focal_length = 4.2647 # FFHQ's FOV intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) c_samples = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) c_samples = c_samples.repeat(w_avg_samples, 1) w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples) # [N, L, C] w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C] w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C] # print('save w_avg to ./w_avg.npy') # np.save('./w_avg.npy',w_avg) w_avg_tensor = torch.from_numpy(w_avg).cuda() w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5 # np.save(w_avg_path, w_avg) # np.save(w_std_path, w_std) else: # w_avg = np.load(w_avg_path) # w_std = np.load(w_std_path) raise Exception(' ') # z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim) # c_samples = c.repeat(w_avg_samples, 1) # w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples) # [N, L, C] # w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C] # w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C] # w_avg_tensor = torch.from_numpy(w_avg).cuda() # w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5 start_w = initial_w if initial_w is not None else w_avg # Setup noise inputs. 
noise_bufs = {name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name} # Load VGG16 feature detector. url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' # url = './networks/vgg16.pt' with dnnlib.util.open_url(url) as f: vgg16 = torch.jit.load(f, map_location=device).eval().to(device) # Features for target image. target_images = target.unsqueeze(0).to(device).to(torch.float32) if target_images.shape[2] > 256: target_images = F.interpolate(target_images, size=(256, 256), mode='area') target_features = vgg16(target_images, resize_images=False, return_lpips=True) start_w = np.repeat(start_w, G.backbone.mapping.num_ws, axis=1) w_opt = torch.tensor(start_w, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=0.1) # Init noise. for buf in noise_bufs.values(): buf[:] = torch.randn_like(buf) buf.requires_grad = True for step in tqdm(range(num_steps)): # Learning rate schedule. t = step / num_steps w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2 lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length) lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length) lr = initial_learning_rate * lr_ramp for param_group in optimizer.param_groups: param_group['lr'] = lr # Synth images from opt_w. w_noise = torch.randn_like(w_opt) * w_noise_scale ws = (w_opt + w_noise) synth_images = G.synthesis(ws,c, noise_mode='const')['image'] # if step % image_log_step == 0: # with torch.no_grad(): # vis_img = (synth_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) # PIL.Image.fromarray(vis_img[0].cpu().numpy(), 'RGB').save(f'{outdir}/{step}.png') # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images. 
synth_images = (synth_images + 1) * (255 / 2) if synth_images.shape[2] > 256: synth_images = F.interpolate(synth_images, size=(256, 256), mode='area') # Features for synth images. synth_features = vgg16(synth_images, resize_images=False, return_lpips=True) dist = (target_features - synth_features).square().sum() # Noise regularization. reg_loss = 0.0 for v in noise_bufs.values(): noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d() while True: reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2 reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2 if noise.shape[2] <= 8: break noise = F.avg_pool2d(noise, kernel_size=2) loss = dist + reg_loss * regularize_noise_weight # if step % 10 == 0: # with torch.no_grad(): # print({f'step {step}, first projection _{w_name}': loss.detach().cpu()}) # Step optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step() logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}') # Normalize noise. with torch.no_grad(): for buf in noise_bufs.values(): buf -= buf.mean() buf *= buf.square().mean().rsqrt() del G return w_opt ```I got errors as shown below:
The error seems to happen at
loss.backward()
I checked that most of the variables, the loss, and the model are on cuda:0,
but I still have had no luck solving this. Do you know how to make the loss backpropagate properly?