facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

Bad rendering results on A100 GPU, but correct on V100 or Tesla T4 GPUs. #1803

Closed hjrPhoebus closed 1 month ago

hjrPhoebus commented 4 months ago

Thank you for your great work. However, I encountered an issue while rendering the same mesh on V100 and A100 GPUs: on the V100 the results look normal, but on the A100 the rendered output appears wrinkled. I used the same Docker image for both GPUs, so the environments should be identical.

[Screenshots: rendered results on V100 (normal) vs. A100 (wrinkled)]

Here is the code I used to render the normal map. The same issue occurs when rendering the depth map. This makes it difficult for me to determine whether the problem lies with my mesh file or with PyTorch3D itself.

import os
import sys
import glob
import torch
import pickle

# import matplotlib.pyplot as plt
import imageio
import numpy as np

# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes, load_obj
import torch.nn.functional as F

# from skimage.io import imread, imsave

# Data structures and functions for rendering
from pytorch3d.structures import Meshes

# from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib
from pytorch3d.transforms import RotateAxisAngle

from pytorch3d.renderer import (
    BlendParams,
    look_at_view_transform,
    FoVPerspectiveCameras,
    FoVOrthographicCameras,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    SoftSilhouetteShader,
    TexturesUV,
    TexturesVertex,
    blending,
    PerspectiveCameras,
)
from PIL import ImageColor
import argparse

# add path for demo utils functions
from tqdm import tqdm

# Setup device
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

def render_normal(verts, faces, save_pth, angle, size, device=device):
    # rotate so that the SMPL-X mesh is always rendered with frontal lighting
    rot = RotateAxisAngle(angle, "Y", device=device)
    rot_verts = rot.transform_points(verts)
    mesh_smplx = Meshes(verts=[rot_verts], faces=[faces.verts_idx])

    mesh_smplx.textures = TexturesVertex(
        verts_features=(mesh_smplx.verts_normals_padded() + 1.0) * 0.5
    )
    # scale
    vertices = mesh_smplx.verts_list()[0]
    vertices = vertices.cpu().numpy()
    up_axis = 1
    scan_scale = 1.8 / (vertices.max(0)[up_axis] - vertices.min(0)[up_axis])
    mesh_smplx.scale_verts_(scan_scale)
    # center
    vertices = mesh_smplx.verts_list()[0]
    vertices = vertices.cpu().numpy()
    center_smpl =  (vertices.max(0) + vertices.min(0)) / 2

    offset = torch.tensor(-center_smpl, device=device)
    mesh_smplx.offset_verts_(offset)

    bg = "black"
    blendparam = BlendParams(1e-4, 1e-8, np.array(ImageColor.getrgb(bg)) / 255.0)

    raster_settings_mesh = RasterizationSettings(
        image_size=size,
        blur_radius=np.log(1.0 / 1e-4) * 1e-7,
        # bin_size=-1,
        bin_size=0,
        faces_per_pixel=1,
    )

    R, T = look_at_view_transform(1.6, 0, 0)
    # cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
    cameras = FoVOrthographicCameras(device=device, R=R, T=T)

    meshRas = MeshRasterizer(cameras=cameras, raster_settings=raster_settings_mesh)
    renderer = MeshRenderer(
        rasterizer=meshRas,
        shader=cleanShader(blend_params=blendparam),
    )
    images = renderer(mesh_smplx)
    rendered_image = images[0, ..., :3].detach().cpu().numpy()
    imageio.imsave(save_pth, (rendered_image * 255).astype(np.uint8))

def render(pth):
    normal_pth = "normal.png"

    smplx_obj_path = pth

    # angle_deg = float(np.loadtxt(front_deg_path))

    ### load smplx mesh
    verts, faces, aux = load_obj(smplx_obj_path, device=device)

    # rot = RotateAxisAngle(-angle_deg, "Y", device=device)
    rot = RotateAxisAngle(180, "X", device=device)
    verts = rot.transform_points(verts)

    # the mesh has already been rotated to face the front
    render_normal(verts, faces, normal_pth, 0, 512)
    # render_depth(verts, faces, normal_pth, 0, 512)

if __name__ == "__main__":
    render("mesh.obj") 

https://cloud.tsinghua.edu.cn/f/3d7f383cd3c84f78add7/?dl=1 is the mesh file I used for rendering, which was obtained using SMPL-X and Trimesh. It is referenced in the code above as "mesh.obj".

bottler commented 1 month ago

I'm sorry I have absolutely no idea why there should be a difference. Perhaps you could log lots of intermediate results and compare them between the environments.
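For example (a sketch, not from the original thread), one way to log intermediate results is to dump the rasterizer fragments and mesh attributes to disk on each GPU and diff them offline; the dump_intermediates helper and file names below are illustrative, not part of the posted code.

# Hypothetical debugging helper: save intermediate tensors from the pipeline
# above so the A100 and V100 runs can be compared offline.
def dump_intermediates(mesh, rasterizer, tag):
    fragments = rasterizer(mesh)
    torch.save(
        {
            "verts": mesh.verts_packed().cpu(),
            "normals": mesh.verts_normals_packed().cpu(),
            "pix_to_face": fragments.pix_to_face.cpu(),
            "zbuf": fragments.zbuf.cpu(),
            "bary_coords": fragments.bary_coords.cpu(),
        },
        f"intermediates_{tag}.pt",
    )

# On each machine, e.g. dump_intermediates(mesh_smplx, meshRas, "a100") / "v100",
# then compare the two dumps on either machine:
# a = torch.load("intermediates_a100.pt")
# v = torch.load("intermediates_v100.pt")
# for k in a:
#     print(k, (a[k].float() - v[k].float()).abs().max().item())

Comparing where the tensors first diverge (vertices, normals, rasterization, or blending) should narrow down which stage behaves differently on the A100.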