liren2515 / GarmentRecovery

Code for "Garment Recovery with Shape and Deformation Priors", CVPR2024

About simulation #3

Open MontaEllis opened 2 months ago

MontaEllis commented 2 months ago

Very nice work! I've run the code and got amazing results. But I wonder how you do the simulation. Could you do me a favor and provide some details about it?

liren2515 commented 2 months ago

Hello, thanks for your interest in our work.

For simulation, you can use any simulation software, such as Blender (with the SMPL-X add-on) or Marvelous Designer. There are many tutorials online.

You can also use customized code for this purpose. We simply used the physics-based loss (borrowed from SNUG) to optimize the garment mesh given a sequence of underlying bodies. But this cannot produce results as good as those from simulation software.
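In short, the per-frame optimization just minimizes the sum of the SNUG energy terms with gradient descent. A condensed sketch of the idea (the loss functions are the ones from snug/snug_helper.py; the inertial term for the sequence is omitted, so this is only meant as an illustration):

import torch
from snug.snug_helper import stretching_energy, bending_energy, gravitational_energy, collision_penalty

def simulate_frame(garment_prev, cloth, vb, nb, iters=500, lr=1e-3):
    # garment_prev: (1, V, 3) garment vertices from the previous frame
    # vb, nb: body vertices and normals for the current frame
    garment = garment_prev.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([garment], lr=lr)
    for _ in range(iters):
        loss = (stretching_energy(garment, cloth)
                + bending_energy(garment, cloth)
                + gravitational_energy(garment, cloth.v_mass)
                + collision_penalty(garment, vb, nb, eps=5e-3))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return garment.detach()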

MontaEllis commented 2 months ago

Thank you very much! I'm wondering if we should determine the pinned vertices ourselves?

liren2515 commented 2 months ago

Yes, you need to determine the pinned vertices yourself. For trousers and skirts it is easy, since you can use the boundary vertices around the waist for pinning. The boundary vertices can be easily found by looking for edges that appear in only a single face.
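For example, with trimesh/numpy it is something like this (just an illustration; the select_boundary helper in utils/isp_cut.py is what we use for it):

import numpy as np
import trimesh

def boundary_vertex_indices(mesh):
    # every face contributes 3 edges; sort the two vertex ids so (a, b) and (b, a) match
    edges = np.sort(mesh.edges, axis=1)
    # edges that appear in exactly one face lie on an open boundary (e.g. the waist or hem)
    unique_edges, counts = np.unique(edges, axis=0, return_counts=True)
    boundary_edges = unique_edges[counts == 1]
    return np.unique(boundary_edges)

# e.g. idx_boundary_v = boundary_vertex_indices(trimesh.load('skirt.obj'))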

MontaEllis commented 1 month ago

I've run the code of SNUG, but I can't reproduce the expected results. Could you do me a favor and open-source the simulation code?

liren2515 commented 1 month ago

It's better if you can use a simulator (Blender or Mixamo). We just made a simple demo, so our code for animation is naive. There are many algorithms for cloth simulation in CG, and you can also find implementations in C++.

The following is the piece of code that we use for simulation. Note that the code is incomplete. You will need to make some modifications, such as providing the cloth/body meshes and the pose sequences.

import os, sys
import numpy as np
import trimesh
import torch

from utils.isp_cut import select_boundary, get_connected_paths_skirt
from smpl_pytorch.body_models import SMPL
from snug.snug_class import Body, Cloth_from_NP, Material
from snug.snug_helper import stretching_energy, bending_energy, gravitational_energy, inertial_term_sequence, collision_penalty

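# update_waist: keep the waist vertices glued to the body.
# It re-computes the waist vertex positions from the closest body triangles
# (via precomputed barycentric coordinates) and returns a mask that freezes
# those vertices during the optimization.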
def update_waist(v_gar, v_body, f_body, idx_v_waist, closest_face_idx, v_barycentric):
    with torch.no_grad():
        tri_b =  v_body[f_body.reshape(-1)].reshape(-1, 3, 3)
        tri_waist = tri_b[closest_face_idx]
        v_waist = (tri_waist * v_barycentric[:, :, None]).sum(dim=-2)

        mask = torch.zeros(len(v_gar)).cuda()
        mask[idx_v_waist] = 1
        mask = 1 - mask

        v_update = torch.zeros_like(v_gar)
        v_update[idx_v_waist] = v_waist

    return mask.detach(), v_update.detach()

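# EFT_seq: quasi-static simulation over a pose sequence.
# For every frame t, it optimizes the per-vertex velocity v_update with Adam so
# that the SNUG energies (stretching, bending, gravity, collision) plus an
# inertial term are minimized, then integrates x_t = x_{t-1} + v_update * dt,
# pins the waist vertices to the body, and exports the body/garment meshes.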
def EFT_seq(garment_batch0, garment_batch1, pose_t, beta, cloth, body, smpl_server, faces_body, idx_v_waist, closest_face_idx, v_barycentric, lr=1e-3, iters=500, time_step=1./30):

    beta = beta.repeat(len(pose_t), 1)
    faces_garment = cloth.f.cpu().numpy()

    verts_body_t, _ = infer_smpl(pose_t, beta, smpl_server)
    verts_body_t.requires_grad = False

    faces_body_cuda = torch.LongTensor(faces_body.astype(int)).cuda()
    closest_face_idx = torch.LongTensor(closest_face_idx).cuda()
    v_barycentric = torch.FloatTensor(v_barycentric).cuda()
    garment_current_record = [garment_batch0, garment_batch1]
    T = len(pose_t)
    v_update = (garment_current_record[-1] - garment_current_record[-2])/time_step
    v_update = v_update.detach()
    v_update.requires_grad = True
    optimizer = torch.optim.Adam([{'params': v_update, 'lr': lr}])
    for t in range(2, T):
        print('t:', t)
        with torch.no_grad():
            body.update_body(verts_body_t[[t]])
            vb = body.vb
            nb = body.nb
            v_prev = v_update.clone().detach()  # velocity from the previous frame
            garment_prev = garment_current_record[-1].clone().detach()
            vb.requires_grad = False
            nb.requires_grad = False
            v_prev.requires_grad = False
            garment_prev.requires_grad = False

            mask, waist_update = update_waist(garment_prev.squeeze(), vb.squeeze(), faces_body_cuda, idx_v_waist, closest_face_idx, v_barycentric)

        for i in range(iters):
            garment_current = garment_prev + v_update*time_step
            garment_current = garment_current*mask[:,None] + waist_update
            v_diff = v_update - v_prev

            loss_strain = stretching_energy(garment_current, cloth)
            loss_bending = bending_energy(garment_current, cloth)#*5
            loss_gravity = gravitational_energy(garment_current, cloth.v_mass)#*2
            loss_collision = collision_penalty(garment_current, vb, nb, eps=5e-3)

            num = torch.einsum('bvi,bvi->bv', v_diff, cloth.v_mass[None, :, None] * v_diff)
            den = 2 * time_step
            loss_inertial = (num / den).sum()

            loss = loss_inertial + loss_strain + loss_bending + loss_gravity + loss_collision
            print('iter: %3d, loss: %0.4f, loss_inertial: %0.4f, loss_strain: %0.4f, loss_bending: %0.4f, loss_gravity: %0.4f, loss_collision: %0.4f'%(i, loss.item(), loss_inertial.item(), loss_strain.item(), loss_bending.item(), loss_gravity.item(), loss_collision.item()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            garment_current = (garment_prev + v_update*time_step).detach()
            garment_current = garment_current*mask[:,None] + waist_update
            garment_current_record.append(garment_current.detach())

            mesh_body = trimesh.Trimesh(verts_body_t[t].squeeze().cpu().numpy(), faces_body)
            mesh_gar = trimesh.Trimesh(garment_current_record[t].squeeze().cpu().numpy(), faces_garment)
            colors_g = np.ones((len(mesh_gar.faces), 4))*np.array([160, 160, 255, 200])[np.newaxis,:]

            mesh_gar.visual.face_colors = colors_g

            mesh_body_name = 'body_%03d_%d.obj'%(t,t)
            mesh_gar_name = 'bottom_%03d_%d.obj'%(t,t)
            mesh_body.export(os.path.join(save_dir, mesh_body_name), include_color=True)
            mesh_gar.export(os.path.join(save_dir, mesh_gar_name), include_color=True)

    body_current_record = []
    for t in range(len(garment_current_record)):
        garment_current_record[t] = trimesh.Trimesh(garment_current_record[t].squeeze().cpu().numpy(), faces_garment)
        body_current_record.append(trimesh.Trimesh(verts_body_t[t].squeeze().cpu().numpy(), faces_body))

    return body_current_record, garment_current_record

def init_smpl_server(gender='f'):
    smpl_server = SMPL(model_path='.../smpl_pytorch',
                            gender=gender,
                            use_hands=False,
                            use_feet_keypoints=False,
                            dtype=torch.float32).cuda()
    return smpl_server

def infer_smpl(pose, beta, smpl_server):
    with torch.no_grad():
        output = smpl_server.forward_custom(betas=beta,
                                    body_pose=pose[:, 3:],
                                    global_orient=pose[:, :3],
                                    return_verts=True,
                                    return_full_pose=True,
                                    v_template=smpl_server.v_template, rectify_root=False)

    verts = output.vertices
    joints = output.joints
    root_J = output.joints[:,[0]]

    return verts, root_J

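# interpolate_pose: linearly blends from the initial pose to the first frame of
# the sequence, then inserts two intermediate frames between every pair of
# consecutive poses so the motion is smooth enough for the per-frame optimization.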
def interpolate_pose(pose_init, pose_t, steps=30):

    interval = (pose_t[0] - pose_init[0])/steps
    interval = interval.unsqueeze(0)

    pose_interp = [pose_init, pose_init, pose_init, pose_init]
    for i in range(steps):
        pose_i = pose_init + interval*i
        pose_interp.append(pose_i)

    pose_interp = torch.cat(pose_interp, dim=0)
    pose_new = torch.cat((pose_interp, pose_t), dim=0)

    pose_new_1 = pose_new[:-1]
    pose_new_2 = pose_new[1:]

    insert1 = (pose_new_2 - pose_new_1)/3 + pose_new_1
    insert2 = (pose_new_2 - pose_new_1)/3*2 + pose_new_1

    pose_new_insert = torch.zeros(len(pose_new_1)*3+1, 72).cuda()
    pose_new_insert[::3] = pose_new
    pose_new_insert[1::3] = insert1
    pose_new_insert[2::3] = insert2

    return pose_new_insert

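# --- driver script (incomplete, as noted above) ---
# The snippet assumes a few things that are not defined here and need to be
# provided by you, e.g.:
#   sample      : one entry of `samples` (the garment to simulate)
#   body_smpl   : a trimesh.Trimesh of the SMPL body in the initial pose
#   pose, beta  : the initial SMPL pose and the body shape parameters
#   load_pose, _default_seq : helpers that load/organize the AMASS pose sequences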
align_dir = '...fitting-data/skirt/processed/align'
bottom_path = '.../fitting-data/skirt/processed/result-offset-new/'
top_path = '.../fitting-data/close-shirt/processed/result-round2-raster2faces-removeCollar/'
output_dir = '.../fitting-data/skirt/animation'

samples = sorted(os.listdir(bottom_path))

smpl_server = init_smpl_server()

bottom = trimesh.load(os.path.join(bottom_path, sample, 'mesh_final_smooth_new.obj'))

material = Material()
cloth_bottom = Cloth_from_NP(bottom.vertices, bottom.faces, material)

idx_boundary_v, boundary_edges = select_boundary(bottom)
path_waist = get_connected_paths_skirt(bottom, idx_boundary_v, boundary_edges)[0]
idx_v_waist = list(set(path_waist))

v_waist = bottom.vertices[idx_v_waist]
base = trimesh.proximity.ProximityQuery(body_smpl)
closest_pt, _, closest_face_idx = base.on_surface(v_waist)
triangles = body_smpl.triangles[closest_face_idx]
v_barycentric = trimesh.triangles.points_to_barycentric(triangles, closest_pt)

body = Body(body_smpl.faces)

data_dir = '.../AMASS-CMU'
data_folder = 'fullset-separateArms'
seqfiles = os.listdir(os.path.join(data_dir, data_folder))
seqfiles = _default_seq(seqfiles)

for seq, files in seqfiles.items():
    count_t = 0
    pose_t = []
    for i in range(len(files)):
        seq_i = '_'.join(files[i].split('_')[:2])
        print(files[i])

        save_dir = os.path.join(output_dir, seq_i)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        pose_t_i, trans_vel = load_pose(i, files, data_dir, data_folder)
        pose_t.append(pose_t_i)

    pose_t = torch.cat(pose_t, dim=0).cuda()
    pose_t = interpolate_pose(pose, pose_t, steps=60)
    pose_t[0] = pose[0]
    garment_batch0 = torch.FloatTensor(bottom.vertices).cuda().unsqueeze(0)
    garment_batch1 = torch.FloatTensor(bottom.vertices).cuda().unsqueeze(0)

    body_current_record, garment_current_record = EFT_seq(garment_batch0, garment_batch1, pose_t.cuda(), beta, cloth_bottom, body, smpl_server, smpl_server.faces, idx_v_waist, closest_face_idx, v_barycentric, lr=1e-3, iters=400)
MontaEllis commented 3 weeks ago

Thanks a lot! However, I noticed some overlap between the garment and the body. Could this affect the simulation process?

liren2515 commented 3 weeks ago

You mean there is collision between the garment and body?