hongfz16 / AvatarCLIP

[SIGGRAPH 2022 Journal Track] AvatarCLIP: Zero-Shot Text-Driven Generation and Animation of 3D Avatars
https://hongfz16.github.io/projects/AvatarCLIP.html
Other
1.06k stars 93 forks source link

AvatarCLIP/AvatarAnimate #20

Open 1390806607 opened 9 months ago

1390806607 commented 9 months ago

class PoseOptimizer(BasePoseGenerator): """ This method will directly optimize SMPL theta with the guidance from CLIP """ def init(self, optim_name: Optional[str] = 'Adam', optim_cfg: Optional[dict] = {'lr': 0.01}, num_iteration: Optional[int] = 500, kwargs): super().init(kwargs) self.optim_name = optim_name self.optim_cfg = optim_cfg self.num_iteration = num_iteration

def get_pose(self, text_feature: Tensor) -> Tensor:
    pose = nn.Parameter(torch.randn(63, requires_grad=True).to(self.device))
    cls = getattr(torch.optim, self.optim_name)
    optimizer = cls([pose], **self.optim_cfg)
    print(optimizer)
    for i in tqdm(range(self.num_iteration)):
        new_pose = pose.to(self.device)
        clip_feature = self.get_pose_feature(new_pose).squeeze(0)
        loss = 1 - F.cosine_similarity(clip_feature, text_feature, dim=-1)
        # loss = loss.mean()
        torch.autograd.set_detect_anomaly(True)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return pose_padding(pose.data).to(self.device)

def get_topk_poses(self, text: str) -> Tensor:
    text_feature = self.get_text_feature(text)
    poses = [self.get_pose(text_feature) for _ in range(self.topk)]
    poses = self.sort_poses_by_score(text, poses)
    poses = torch.stack(poses, dim=0)
    return poses

root@autodl-container-60e5119152-54c37f51:~/autodl-tmp/AvatarCLIP/AvatarAnimate# python main.py --conf confs/pose_ablation/pose_optimizer/argue.conf 2023-10-28 14:55:47.830 | INFO | human_body_prior.tools.model_loader:load_model:97 - Loaded model in eval mode with trained weights: data/vposer/snapshots/V02_05_epoch=08_val_loss=0.03.ckpt Adam ( Parameter Group 0 amsgrad: False betas: (0.9, 0.999) eps: 1e-08 lr: 0.01 maximize: False weight_decay: 0 ) 0%| | 0/500 [00:00<?, ?it/s] Traceback (most recent call last): File "main.py", line 52, in main(args.conf) File "main.py", line 27, in main candidate_poses = pose_generator.get_topk_poses(text) File "/root/autodl-tmp/AvatarCLIP/AvatarAnimate/models/pose_generation.py", line 141, in get_topk_poses poses = [self.get_pose(textfeature) for in range(self.topk)] File "/root/autodl-tmp/AvatarCLIP/AvatarAnimate/models/pose_generation.py", line 141, in poses = [self.get_pose(textfeature) for in range(self.topk)] File "/root/autodl-tmp/AvatarCLIP/AvatarAnimate/models/pose_generation.py", line 135, in get_pose loss.backward() File "/root/miniconda3/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) File "/root/miniconda3/lib/python3.8/site-packages/torch/autograd/init.py", line 173, in backward Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

1390806607 commented 9 months ago

from abc import ABCMeta, abstractmethod from typing import Optional, Union, Tuple import torch from torch import Tensor import torch.nn as nn import torch.nn.functional as F from torch import distributions

import numpy as np import clip import smplx from tqdm import tqdm

from .render import render_one_batch from human_body_prior.tools.model_loader import load_model from human_body_prior.models.vposer_model import VPoser

def pose_padding(pose): assert pose.shape[-1] == 69 or pose.shape[-1] == 63 if pose.shape[-1] == 63: padded_zeros = torch.zeros_like(pose)[..., :6] pose = torch.cat((pose, padded_zeros), dim=-1) return pose

class BasePoseGenerator(nn.Module, metaclass=ABCMeta): """ Base class for pose generation """ def init(self, name: str, topk: Optional[int] = 5, smpl_path: Optional[str] = '../smpl_models', vposer_path: Optional[str] = 'data/vposer'): super().init() self.name = name self.topk = topk self.device = "cuda" if torch.cuda.is_available() else "cpu" assert self.device == "cuda" # neuralrender does not support inference on cpu self.clip, = clip.load('ViT-B/32', self.device) self.smpl = smplx.create(smplpath, 'smpl').to(self.device) self.vp, = load_model( vposer_path, model_code=VPoser, remove_words_in_model_weights='vp_model.', disable_grad=True) self.vp = self.vp.to(self.device) self.clip.eval() self.vp.eval()

@abstractmethod
def get_topk_poses(self,
                   text: str):
    raise NotImplementedError()

def get_text_feature(self, text: str) -> Tensor:
    text = clip.tokenize([text]).to(self.device)
    with torch.no_grad():
        text_features = self.clip.encode_text(text)
        text_feature = text_features[0]
    return text_feature

def get_pose_feature(self,
                     pose: Tensor,
                     angles: Optional[Union[Tuple[float], None]] = None) -> Tensor:

pose = pose_padding(pose)

    if len(pose.shape) == 1:
        pose = pose.unsqueeze(0)
    bs = pose.shape[0]
    # fix the orientation
    global_orient = torch.zeros(bs, 3).type_as(pose)
    global_orient[:, 0] = np.pi / 2
    output = self.smpl(
        body_pose=pose,
        global_orient=global_orient)
    v = output.vertices
    f = self.smpl.faces
    f = torch.from_numpy(f.astype(np.int32)).unsqueeze(0).repeat(bs, 1, 1).to(self.device)
    if angles is None:
        angles = (120, 150, 180, 210, 240)
    images = render_one_batch(v, f, angles, self.device)
    images = F.interpolate(images, size=224)
    mean = np.array([0.48145466, 0.4578275, 0.40821073])
    std = np.array([0.26862954, 0.26130258, 0.27577711])
    images = images - torch.from_numpy(mean).reshape(1, 3, 1, 1).to(self.device)
    images = images - torch.from_numpy(std).reshape(1, 3, 1, 1).to(self.device)
    num_camera = len(angles)
    image_embed = self.clip.encode_image(images).float().view(num_camera, -1, 512)
    return image_embed.mean(0)

def calculate_pose_score(self, text: str, pose: Tensor) -> float:
    text_feature = self.get_text_feature(text).unsqueeze(0)
    pose_feature = self.get_pose_feature(pose)
    score = F.cosine_similarity(text_feature, pose_feature).item()
    return float(score)

def sort_poses_by_score(self, text, poses):
    poses.sort(key=lambda x: self.calculate_pose_score(text, x), reverse=True)
    return poses

class PoseOptimizer(BasePoseGenerator): """ This method will directly optimize SMPL theta with the guidance from CLIP """ def init(self, optim_name: Optional[str] = 'Adam', optim_cfg: Optional[dict] = {'lr': 0.01}, num_iteration: Optional[int] = 500, kwargs): super().init(kwargs) self.optim_name = optim_name self.optim_cfg = optim_cfg self.num_iteration = num_iteration

def get_pose(self, text_feature: Tensor) -> Tensor:
    pose = nn.Parameter(torch.randn(69))
    cls = getattr(torch.optim, self.optim_name)
    optimizer = cls([pose], **self.optim_cfg)
    for i in tqdm(range(self.num_iteration)):
        new_pose = pose.to(self.device)
        clip_feature = self.get_pose_feature(new_pose).squeeze(0)
        loss = 1 - F.cosine_similarity(clip_feature, text_feature, dim=-1)
        loss = loss.mean()
        optimizer.zero_grad()
        torch.autograd.set_detect_anomaly(True)
        loss.backward()
        optimizer.step()
    return pose_padding(pose.data).to(self.device)

def get_topk_poses(self, text: str) -> Tensor:
    text_feature = self.get_text_feature(text)
    poses = [self.get_pose(text_feature) for _ in range(self.topk)]
    poses = self.sort_poses_by_score(text, poses)
    poses = torch.stack(poses, dim=0)
    return poses

class VPoserOptimizer(BasePoseGenerator): """ This method will optimize SMPL theta in the latent space of VPoser """ def init(self, optim_name: Optional[str] = 'Adam', optim_cfg: Optional[dict] = {'lr': 0.01}, num_iteration: Optional[int] = 500, kwargs): super().init(kwargs) self.optim_name = optim_name self.optim_cfg = optim_cfg self.num_iteration = num_iteration

def get_pose(self, text_feature: Tensor) -> Tensor:
    latent_code = nn.Parameter(torch.randn(32))
    cls = getattr(torch.optim, self.optim_name)
    optimizer = cls([latent_code], **self.optim_cfg)
    for i in tqdm(range(self.num_iteration)):
        new_latent_code = latent_code.to(self.device).unsqueeze(0)
        new_pose = self.vp.decode(new_latent_code)['pose_body']
        new_pose = new_pose.contiguous().view(-1)
        clip_feature = self.get_pose_feature(new_pose).squeeze(0)
        loss = 1 - F.cosine_similarity(clip_feature, text_feature, dim=-1)
        loss = loss.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return pose_padding(new_pose.detach())

def get_topk_poses(self, text: str) -> Tensor:
    text_feature = self.get_text_feature(text)
    poses = [self.get_pose(text_feature) for _ in range(self.topk)]
    poses = self.sort_poses_by_score(text, poses)
    poses = torch.stack(poses, dim=0)
    return poses

class VPoserRealNVP(BasePoseGenerator): """ This method will generate SMPL theta from a pretrained conditional RealNVP. The code is adapted from https://github.com/senya-ashukha/real-nvp-pytorch

`dim` is the dimension of both input and output. (for vposer is 32)
`hdim` is the dimension of hidden layer.
"""
def __init__(self,
             dim: Optional[int] = 32,
             hdim: Optional[int] = 256,
             num_block: Optional[int] = 8,
             num_sample: Optional[int] = 10,
             num_batch: Optional[int] = 50,
             ckpt_path: Optional[str] = 'data/pose_realnvp.pth',
             **kwargs):

    super().__init__(**kwargs)
    self.prior = distributions.MultivariateNormal(
        torch.zeros(dim).to(self.device), torch.eye(dim).to(self.device))
    self.dim = dim
    self.num_sample = num_sample
    self.num_batch = num_batch
    self.s = nn.ModuleList()
    self.t = nn.ModuleList()
    self.num_block = num_block
    mask = torch.randn(num_block, 1, dim)
    mask[mask > 0] = 1
    mask[mask < 0] = 0
    self.register_buffer('mask', mask)
    for i in range(num_block):
        self.s.append(
            nn.Sequential(
                nn.Linear(dim + 512, hdim),  # concat clip feature
                nn.LeakyReLU(),
                nn.Linear(hdim, hdim),
                nn.LeakyReLU(),
                nn.Linear(hdim, dim),
                nn.Tanh()
            )
        )
        self.t.append(
            nn.Sequential(
                nn.Linear(dim + 512, hdim),
                nn.LeakyReLU(),
                nn.Linear(hdim, hdim),
                nn.LeakyReLU(),
                nn.Linear(hdim, dim)
            )
        )
    data = torch.load(ckpt_path, map_location='cpu')
    self.load_state_dict(data['state_dict'], strict=False)
    self.s = self.s.to(self.device)
    self.t = self.t.to(self.device)
    self.mask = self.mask.to(self.device)
    self.eval()

def decode(self, x: Tensor, features: Tensor) -> Tensor:
    for i in range(len(self.t)):
        x_ = x * self.mask[i]
        trans = torch.cat((x_, features), dim=-1)
        s = self.s[i](trans) * (1 - self.mask[i])
        t = self.t[i](trans) * (1 - self.mask[i])
        x = x_ + (1 - self.mask[i]) * (x * torch.exp(s) + t)
    return x

def sample(self, bs: int, features: Tensor) -> Tensor:
    z = self.prior.sample((bs, 1)).squeeze(1).to(self.device)
    if len(features.shape) == 1:
        features = features.unsqueeze(0)
    features = features.repeat(bs, 1)
    x = self.decode(z, features)
    return x

def encode(self, x: Tensor, features: Tensor):
    """
    This is only used during training.
    """
    log_det = torch.zeros(x.shape[0]).type_as(x)
    z = x
    for i in reversed(range(self.num_block)):
        z_ = self.mask[i] * z
        trans = torch.cat((z_, features), dim=-1)
        s = self.s[i](trans) * (1 - self.mask[i])
        t = self.t[i](trans) * (1 - self.mask[i])
        z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
        log_det -= s.sum(dim=1)
    return z, log_det

def get_pose(self, text_feature: Tensor) -> Tensor:
    text_feature = text_feature.unsqueeze(0)
    best_score = 0
    with torch.no_grad():
        for i in tqdm(range(self.num_batch)):
            latent_codes = self.sample(self.num_sample, text_feature)
            poses = self.vp.decode(latent_codes)['pose_body'].reshape(self.num_sample, -1)
            pose_feature = self.get_pose_feature(poses)
            score = F.cosine_similarity(pose_feature, text_feature)
            idx = torch.argmax(score)
            if score[idx] > best_score:
                best_pose = poses[idx]
                best_score = score[idx]
    return best_pose

def get_topk_poses(self, text: str) -> Tensor:
    text_feature = self.get_text_feature(text)
    poses = [self.get_pose(text_feature) for _ in range(self.topk)]
    poses = self.sort_poses_by_score(text, poses)
    poses = torch.stack(poses, dim=0)
    return poses

class VPoserCodebook(BasePoseGenerator): """ This method will find out the poses which are most similar with given text from a codebook. """ def init(self, codebook_path='data/codebook.pth', pre_topk=40, filter_threshold=0.07, kwargs): super().init(kwargs) data = torch.load(codebook_path) self.codebook = data['codebook'].to(self.device) self.codebook_embedding = data['codebook_embedding'].to(self.device) self.pre_topk = pre_topk self.filter_threshold = filter_threshold

def suppress_duplicated_poses(self, poses: Tensor, threshold: float) -> Tensor:
    new_poses = []
    for pose in poses:
        if len(new_poses) == 0:
            new_poses.append(pose)
        else:
            min_dis = 10
            for j in range(len(new_poses)):
                cur_dis = torch.abs(pose - new_poses[j]).mean()
                min_dis = min(cur_dis, min_dis)
            if min_dis > threshold:
                new_poses.append(pose)
    poses = torch.stack(new_poses, dim=0)
    return poses

def get_topk_poses(self, text: str) -> Tensor:
    with torch.no_grad():
        text_feature = self.get_text_feature(text).unsqueeze(0)
        score = F.cosine_similarity(
            self.codebook_embedding, text_feature).view(-1)
        _, indexs = torch.topk(score, self.pre_topk)
        latent_codes = self.codebook[indexs]
        poses = self.vp.decode(latent_codes)['pose_body'].reshape(self.pre_topk, -1)
        poses = self.suppress_duplicated_poses(poses, threshold=self.filter_threshold)
        poses = poses[:self.topk]
    return poses

python main.py --conf confs/pose_ablation/pose_optimizer/argue.conf in AvatarCLIP/AvatarAnimate, the following error was reported
image

1390806607 commented 9 months ago

The environment is as follows: antlr4-python3-runtime 4.9.3 anyio 3.7.1 argon2-cffi 23.1.0 argon2-cffi-bindings 21.2.0 asttokens 2.4.1 attrs 23.1.0 backcall 0.2.0 beautifulsoup4 4.12.2 bleach 6.0.0 body-visualizer 1.1.0 certifi 2022.12.7 cffi 1.15.1 chumpy 0.70 clip 1.0 colorama 0.4.6 comm 0.1.4 cycler 0.11.0 dataclasses 0.6 debugpy 1.7.0 decorator 5.1.1 defusedxml 0.7.1 dotmap 1.3.30 entrypoints 0.4 exceptiongroup 1.1.3 executing 2.0.0 fastjsonschema 2.18.1 fonttools 4.38.0 freetype-py 2.4.0 ftfy 6.1.1 future 0.18.3 human-body-prior 2.2.2.0 icecream 2.1.0 idna 3.4 imageio 2.31.2 importlib-metadata 6.7.0 importlib-resources 5.12.0 ipykernel 6.16.2 ipython 7.34.0 ipython-genutils 0.2.0 ipywidgets 8.1.1 jedi 0.19.1 Jinja2 3.1.2 jsonschema 4.17.3 jupyter_client 7.4.9 jupyter_core 4.12.0 jupyter-server 1.24.0 jupyterlab-pygments 0.2.2 jupyterlab-widgets 3.0.9 kiwisolver 1.4.5 loguru 0.7.2 MarkupSafe 2.1.3 matplotlib 3.5.3 matplotlib-inline 0.1.6 mistune 3.0.2 mkl-fft 1.3.1 mkl-random 1.2.2 mkl-service 2.4.0 nbclassic 1.0.0 nbclient 0.7.4 nbconvert 7.6.0 nbformat 5.8.0 nest-asyncio 1.5.8 networkx 2.6.3 neural-renderer-pytorch 1.1.3 notebook 6.5.6 notebook_shim 0.2.3 numpy 1.21.0 omegaconf 2.3.0 open3d-python 0.7.0.0 opencv-python 4.5.2.52 packaging 23.2 pandocfilters 1.5.0 parso 0.8.3 pexpect 4.8.0 pickleshare 0.7.5 Pillow 9.4.0 pip 22.3.1 pkgutil_resolve_name 1.3.10 prometheus-client 0.17.1 prompt-toolkit 3.0.39 psutil 5.9.6 ptyprocess 0.7.0 pycparser 2.21 pyglet 1.5.9 Pygments 2.16.1 pyhocon 0.3.57 PyMCubes 0.1.2 PyOpenGL 3.1.0 pyparsing 3.1.1 pyrender 0.1.45 pyrsistent 0.19.3 python-dateutil 2.8.2 PyWavelets 1.3.0 PyYAML 6.0.1 pyzmq 24.0.1 regex 2023.10.3 scikit-image 0.19.3 scipy 1.7.0 Send2Trash 1.8.2 setuptools 65.6.3 shapely 2.0.2 six 1.16.0 smplx 0.1.28 sniffio 1.3.0 soupsieve 2.4.1 terminado 0.17.1 tifffile 2021.11.2 tinycss2 1.2.1 torch 1.7.0 torchaudio 0.7.0a0+ac17b64 torchvision 0.8.0 tornado 6.2 tqdm 4.50.2 traitlets 5.9.0 trimesh 3.9.8 typing_extensions 4.3.0 wcwidth 0.2.8 webencodings 0.5.1 websocket-client 1.6.1 wheel 0.38.4 widgetsnbextension 4.0.9 zipp 3.15.0

1390806607 commented 9 months ago

@hongfz16 @TianxingWu @mingyuan-zhang Could you take a look at it for me, please.Thank you very much

julyaugust12345 commented 7 months ago

@1390806607 Do you solve this problem?I meet the same problem. RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 6890]], which is output 0 of SelectBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True). It would be very appreciate if you help me.

Daffodily commented 6 months ago

@1390806607 Do you solve this problem?I meet the same problem. RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 6890]], which is output 0 of SelectBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True). It would be very appreciate if you help me.

I also encounter with the same problem. Have you solved it?

julyaugust12345 commented 6 months ago

@1390806607 Do you solve this problem?I meet the same problem. RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 6890]], which is output 0 of SelectBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True). It would be very appreciate if you help me.

I also encounter with the same problem. Have you solved it?

No.I don't solve it.Have you solved it?