facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

Size Mismatch between tensors when rendering SMPL model #1758

Closed mhaeming closed 8 months ago

mhaeming commented 8 months ago

I'm trying to render a textured SMPL body model using pytorch3d. The texture is taken from the SMPLitex repository, and the mesh is the official "smpl_uv.obj" from https://smpl.is.tue.mpg.de/.

from matplotlib import pyplot as plt
import torch
from torchvision.io import read_image
from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    TexturesUV,
    TexturesVertex,
    FoVPerspectiveCameras,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    RasterizationSettings,
    PointLights,
)

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

texture_image = read_image("m_01_alb.002.png").to(device=device) / 255.0

verts, faces, aux = load_obj("data/smpl_uv.obj", device=device)

# tex = TexturesVertex(torch.ones_like(verts)[None])
tex = TexturesUV(texture_image[None], faces.verts_idx[None], aux.verts_uvs[None])

meshes = Meshes(verts=[verts], faces=[faces.verts_idx], textures=tex)

R, T = look_at_view_transform(2.7, 0, 180, device=device)
camera = FoVPerspectiveCameras(device=device, R=R, T=T)

raster_settings = RasterizationSettings(image_size=512)

lights = PointLights(location=[[0, 0, -3.0]], device=device)

renderer = MeshRenderer(
    rasterizer=MeshRasterizer(cameras=camera, raster_settings=raster_settings),
    shader=SoftPhongShader(device=device, cameras=camera, lights=lights),
)

images = renderer(meshes)

plt.imshow(images.squeeze().cpu().numpy())

The size mismatch occurs when the renderer calls https://github.com/facebookresearch/pytorch3d/blob/7566530669203769783c94024c25a39e1744e4ed/pytorch3d/renderer/mesh/shading.py#L96, throwing:

RuntimeError: The size of tensor a (3) must match the size of tensor b (4096) at non-singleton dimension 4

where 4096 is the width and height of the texture.

Am I overlooking a preprocessing step for TexturesUV? The issue does not arise when using TexturesVertex as in the commented-out line. I would appreciate any ideas about what might be going wrong here. Thanks in advance!

bottler commented 8 months ago

The texture_image supplied to TexturesUV should have shape [B=1, H, W, C]. read_image from torchvision gives you [C, H, W], so you need to permute it.
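A minimal sketch of that change, assuming the rest of the snippet above stays the same; the permute reorders the channels-first tensor from read_image into the channels-last layout TexturesUV expects:

# read_image returns a [C, H, W] uint8 tensor; dividing by 255 converts it
# to float, and permute reorders it to [H, W, C] before batching.
texture_image = read_image("m_01_alb.002.png").to(device=device) / 255.0  # [C, H, W]
texture_image = texture_image.permute(1, 2, 0)                            # [H, W, C]

tex = TexturesUV(
    maps=texture_image[None],          # [1, H, W, C]
    faces_uvs=faces.verts_idx[None],
    verts_uvs=aux.verts_uvs[None],
)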

mhaeming commented 8 months ago

Thanks a lot, that was the problem! Now I'm experiencing the same weird behavior as mentioned in #1601. But I think this issue can be closed.

dancasas commented 6 months ago

Hello @mhaeming, you can have a look at the SMPLitex repository to see how to render textured SMPL meshes with PyTorch3D. I have it implemented here:

https://github.com/dancasas/SMPLitex/blob/0476822a8273f96ee8ab3463560d6968c46697c7/scripts/utils/renderer/pytorch3d_renderer.py#L189