facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/

plot_scene doesn't plot textured meshes #696

Closed: ck-amrahd closed this issue 3 years ago

ck-amrahd commented 3 years ago

Hi, I am trying to view my rendered 3D mesh from different viewpoints, and I am using the pytorch3d.vis.plot_scene function for this. I load the mesh and apply a new texture like this:

    mesh.textures = TexturesUV(texture.to(device), faces_uvs=faces_uvs, verts_uvs=verts_uvs)

Then I use the plot_scene function as follows:

    fig = plot_scene({"subplot": {"cow_mesh": mesh}})
    fig.show()

But it plots a plain blue 3D mesh without the texture I applied. When I render the mesh from a given viewpoint and plot the image, the texture is applied correctly. Am I missing something? Thank you.

patricklabatut commented 3 years ago

The 3D Plotly-based visualization of PyTorch3D objects only handles TexturesVertex (see the corresponding code for reference), so unfortunately it is not possible to visualize texture images applied to a mesh via TexturesUV in 3D. AFAIK this is not something that can be addressed in PyTorch3D: Plotly simply does not support rendering textured 3D mesh data.

What could be done instead is to sample the texture image at the mesh vertices and use the sampled vertex colors in a TexturesVertex to visualize the correspondingly colored mesh in 3D with plot_scene(), which is supported.
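A minimal NumPy sketch of the sampling step, assuming the texture is an (H, W, C) array and UV coordinates lie in [0, 1] with u running left to right and v = 0 at the bottom row of the image (this appears to match TexturesUV with its default align_corners=True, but treat the convention as an assumption). sample_texture_nearest is a hypothetical helper, not a PyTorch3D API; with a PyTorch3D mesh its inputs would come from mesh.textures.maps_padded()[0] and mesh.textures.verts_uvs_list()[0] converted to NumPy.

```python
import numpy as np


def sample_texture_nearest(texture: np.ndarray, uvs: np.ndarray) -> np.ndarray:
    """Return the texel color at each UV coordinate (nearest-neighbour lookup).

    texture: (H, W, C) image; uvs: (T, 2) array of (u, v) values in [0, 1].
    Assumed convention: u maps to columns 0..W-1 and v maps to rows H-1..0,
    i.e. v = 0 is the bottom row of the image.
    """
    h, w = texture.shape[:2]
    x = uvs[:, 0] * (w - 1)          # fractional column index
    y = (1.0 - uvs[:, 1]) * (h - 1)  # fractional row index (v is flipped)
    cols = np.clip(np.rint(x).astype(int), 0, w - 1)
    rows = np.clip(np.rint(y).astype(int), 0, h - 1)
    return texture[rows, cols]       # (T, C)
```

Each UV coordinate gets one color here; a mesh vertex can still have several UV coordinates (one per face), so these per-UV colors still have to be combined per vertex.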

Alternatively, you can also use the pytorch3d.vis.texture_vis.texturesuv_image_matplotlib() function to visualize the 2D texture image with the mesh vertices projected into texture space (see this tutorial for an example).

ck-amrahd commented 3 years ago

Thanks @patricklabatut, that makes sense. I am familiar with the texturesuv_image_matplotlib tutorial, but what I need is to rotate the mesh in 3D and see the result. I will try to sample the texture per vertex and use that instead of TexturesUV. Thank you.

ck-amrahd commented 3 years ago

Hi, just to make sure I am doing this correctly: I can't find any function in PyTorch3D that does this explicitly. I need to iterate over the faces, get the correspondence between each vertex and its UV coordinates, and maybe average the colors if a vertex belongs to multiple faces, am I right? It seems like I need to do this from the OBJ file. Please correct me if there's a better way. Thank you.

patricklabatut commented 3 years ago

I don't find any function in Pytorch3D that does this explicitly.

No, I don't think there is one ready to use specifically for this.

I need to iterate over faces and get the correspondence between vertex and UV and maybe average if one vertex is in multiple faces, am I right?

Yes. You can use TexturesUV.centers_for_image() to get the (x, y) locations of the vertex UV coordinates in the texture image and then retrieve the corresponding colors (possibly with bilinear interpolation). That yields a sampled color for each vertex UV coordinate. Since a vertex may belong to multiple faces, with a different UV coordinate in each of them, you should find all the faces the vertex belongs to (and thus all its UV coordinates and sampled colors) and combine them, for instance by averaging the sampled colors.
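The per-vertex averaging can also be done without a Python loop over faces. This is a hedged NumPy sketch, assuming corner_colors already holds the color sampled at each face corner's own UV coordinate (row i of faces_uvs gives the UV indices for row i of faces); average_corner_colors is a hypothetical helper name.

```python
import numpy as np


def average_corner_colors(faces: np.ndarray, corner_colors: np.ndarray,
                          num_verts: int) -> np.ndarray:
    """Average the colors sampled at every face corner that uses a vertex.

    faces:         (F, 3) vertex indices per face.
    corner_colors: (F, 3, C) color sampled at each face corner's UV coordinate.
    Returns a (num_verts, C) array; vertices used by no face stay at 0.
    """
    channels = corner_colors.shape[-1]
    sums = np.zeros((num_verts, channels), dtype=np.float64)
    counts = np.zeros(num_verts, dtype=np.int64)
    flat_verts = faces.reshape(-1)  # (3F,) vertex index of every face corner
    # np.add.at accumulates repeated indices, unlike sums[flat_verts] += ...
    np.add.at(sums, flat_verts, corner_colors.reshape(-1, channels))
    np.add.at(counts, flat_verts, 1)
    used = counts > 0
    sums[used] /= counts[used][:, None]
    return sums
```

np.add.at is used because a vertex shared by several faces appears many times in flat_verts, and plain fancy-index `+=` would apply only one of those updates.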

ck-amrahd commented 3 years ago

Thanks @patricklabatut for the prompt reply. It seems like I am in the right direction. Thank you.

ck-amrahd commented 3 years ago

Hi @patricklabatut, I have implemented this feature and I can paste my code here or contribute if you guys want to have this feature in Pytorch3D. Thank you.

patricklabatut commented 3 years ago

I can paste my code here or contribute if you guys want to have this feature in Pytorch3D.

Whatever works for you depending on your time: either a paste and we take it from there or a PR that may require a few iterations. I also don't think we have tried this internally, so if you could share results, that would be great, thanks!

ck-amrahd commented 3 years ago

Thanks @patricklabatut, I did it last night. I will go over it again and then create a PR. Sure, I will share the results.

ck-amrahd commented 3 years ago

Hi @patricklabatut I have verified the code with the cow mesh. I will optimize the code tonight and create a PR.

ck-amrahd commented 3 years ago

Hi @patricklabatut, sorry I wasn't able to create a PR, so I will paste the code here. I hope you can check it for correctness (I have tested it and it seems to work) and add some improvements such as interpolation: currently I just take the nearest texel.

    # Convert TexturesUV to TexturesVertex and check whether get_vertex_color works properly:
    # render both meshes from random viewpoints side by side, then show the vertex-colored
    # mesh in an interactive 3D plot. (The helpers were originally in Utils/utils.py and Utils/plot.py.)

    import random

    import matplotlib.pyplot as plt
    import torch
    from pytorch3d.io import load_objs_as_meshes
    from pytorch3d.renderer import (
        FoVPerspectiveCameras,
        MeshRasterizer,
        MeshRenderer,
        PointLights,
        RasterizationSettings,
        SoftPhongShader,
        TexturesVertex,
        look_at_view_transform,
    )
    from pytorch3d.structures import Meshes
    from pytorch3d.vis.plotly_vis import AxisArgs, plot_scene


    class Vertex:
        def __init__(self, idx, position):
            self.idx = idx
            self.position = position  # [x, y, z]
            self.color = []  # running sum of [r, g, b] over all occurrences of this vertex
            self.update_count = 0

        def update_color(self, new_color):
            if not self.color:
                self.color = list(new_color)
            else:
                # build a list, not a lazy map(): a map object can only be consumed once
                self.color = [x + y for x, y in zip(self.color, new_color)]
            self.update_count += 1


    def get_vertex_color(mesh):
        # returns a V x 3 tensor: for each vertex, the nearest-texel color averaged over all faces using it
        vertex_dict = {}
        for idx, pos in enumerate(mesh.verts_packed().tolist()):  # V x 3
            vertex_dict[idx] = Vertex(idx, pos)

        # (x, y) pixel location of every UV coordinate in the texture image: T x 2
        centers = mesh.textures.centers_for_image(index=0).cpu().numpy()
        # fetch the texture map and index tensors once instead of once per face
        texture_map = mesh.textures.maps_padded()[0].detach().cpu().numpy()  # H x W x 3
        height, width = texture_map.shape[:2]
        faces = mesh.faces_packed().tolist()  # F x 3 vertex indices
        faces_uvs = mesh.textures.faces_uvs_list()[0].tolist()  # F x 3 UV indices

        for face, face_uvs in zip(faces, faces_uvs):
            for v_idx, vt_idx in zip(face, face_uvs):
                x, y = centers[vt_idx]
                # round to the nearest texel and clamp to the image bounds
                col = min(max(int(round(x)), 0), width - 1)
                row = min(max(int(round(y)), 0), height - 1)
                vertex_dict[v_idx].update_color(texture_map[row, col].tolist())

        colors = torch.ones_like(mesh.verts_packed())  # vertices used by no face stay white
        for idx, item in vertex_dict.items():
            if item.update_count > 0:
                colors[idx] = torch.tensor([x / item.update_count for x in item.color])

        return colors


    def normalize_mesh(mesh):
        # center the mesh at the origin and scale it to fit in [-1, 1]^3 so that the same lights and
        # similar camera R, T work for every object; modifies `mesh` in place and also returns it
        verts = mesh.verts_packed()
        center = verts.mean(0)
        scale = (verts - center).abs().max()
        mesh.offset_verts_(-center.expand(verts.shape[0], 3))
        mesh.scale_verts_(1.0 / float(scale))
        return mesh


    def get_cameras(batch_size, device):
        # sample a new camera distance
        dist = round(random.uniform(2.0, 3.0), 2)

        # Get a batch of viewing angles. Plain torch.linspace(0, 180, batch_size) and
        # torch.linspace(-180, 180, batch_size) give the same views every time, so instead
        # sample batch_size * 2 values for each, randomly permute them and keep batch_size of them
        elev_double = torch.linspace(0, 180, batch_size * 2)
        azim_double = torch.linspace(-180, 180, batch_size * 2)

        indices = torch.randperm(batch_size * 2)[:batch_size]
        elev = elev_double[indices]
        azim = azim_double[indices]

        R, T = look_at_view_transform(dist, elev=elev, azim=azim)
        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
        return cameras


    def get_lights(device):
        # place a single point light on the +x axis, two units from the (normalized) object's center
        lights = PointLights(device=device, location=[[2.0, 0.0, 0.0]])
        return lights


    def get_renderer(image_size, cameras, lights, device):
        raster_settings = RasterizationSettings(
            image_size=image_size,
            blur_radius=0.0,
            faces_per_pixel=1,
        )

        # Create a Phong renderer by composing a rasterizer and a shader. The shader interpolates
        # the per-pixel texture (UV or vertex colors), then applies the Phong lighting model
        renderer = MeshRenderer(
            rasterizer=MeshRasterizer(
                cameras=cameras,
                raster_settings=raster_settings
            ),
            shader=SoftPhongShader(
                device=device,
                cameras=cameras,
                lights=lights
            )
        )

        return renderer


    def plot_original_and_generated(original, generated):
        plt.figure(figsize=(14, 7))
        plt.subplot(1, 2, 1)
        plt.imshow(original[0, ..., :3].cpu().numpy())
        plt.title("Rendered with original texture")
        plt.axis("off")

        plt.subplot(1, 2, 2)
        plt.imshow(generated[0, ..., :3].cpu().numpy())
        plt.title("Rendered with synthetic texture")
        plt.axis("off")
        plt.show()


    def visualize_3d(mesh, calc_vertex_texture=True):
        # plot_scene only supports TexturesVertex (Plotly cannot render UV-textured meshes),
        # so sample the texture per vertex and use that instead
        if calc_vertex_texture:
            print("Please wait, generating vertex colors from texture...")
            colors = get_vertex_color(mesh)
            verts_rgb = colors.unsqueeze(dim=0)  # 1 x V x 3
            textures = TexturesVertex(verts_features=verts_rgb)
            mesh = Meshes(verts=[mesh.verts_packed()], faces=[mesh.faces_packed()], textures=textures)

        fig = plot_scene({"": {"mesh": mesh}},
                         xaxis={"showgrid": False, "zeroline": False, "visible": False},
                         yaxis={"showgrid": False, "zeroline": False, "visible": False},
                         zaxis={"showgrid": False, "zeroline": False, "visible": False},
                         axis_args=AxisArgs(showgrid=False))
        fig.show()


    if __name__ == "__main__":
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        image_size = 1024
        num_plots = 10

        obj_filename = "./Data/cow_mesh/cow.obj"
        mesh = load_objs_as_meshes([obj_filename], device=device)

        # mesh with TexturesUV (normalize_mesh modifies `mesh` in place and returns it)
        mesh_tuv = normalize_mesh(mesh)

        # mesh with TexturesVertex: one color per vertex sampled from the UV texture
        colors = get_vertex_color(mesh_tuv)
        verts_rgb = colors.unsqueeze(dim=0)  # 1 x V x 3
        textures = TexturesVertex(verts_features=verts_rgb)
        mesh_tv = Meshes(verts=[mesh_tuv.verts_packed()], faces=[mesh_tuv.faces_packed()], textures=textures)

        for _ in range(num_plots):
            cameras = get_cameras(batch_size=1, device=device)
            lights = get_lights(device=device)
            renderer = get_renderer(image_size=image_size, cameras=cameras, lights=lights, device=device)

            # render images: each is a tensor of shape [batch_size, image_size, image_size, 4 (RGBA)]
            image_tuv = renderer(mesh_tuv, cameras=cameras, lights=lights)
            image_tv = renderer(mesh_tv, cameras=cameras, lights=lights)

            plot_original_and_generated(image_tuv, image_tv)

        visualize_3d(mesh_tv, calc_vertex_texture=False)

I have copied this code from multiple files in my project; I hope it won't take too long to organize.
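The nearest-texel lookup in get_vertex_color could be replaced by the bilinear interpolation suggested earlier in the thread. Below is a NumPy sketch under an assumed UV convention (u maps to columns 0 to W-1, v = 0 is the bottom row). sample_texture_bilinear is a hypothetical name; it blends the four texels surrounding each UV location, weighted by distance.

```python
import numpy as np


def sample_texture_bilinear(texture: np.ndarray, uvs: np.ndarray) -> np.ndarray:
    """Bilinearly interpolate an (H, W, C) texture at (T, 2) UV coordinates.

    Assumed convention: u maps to columns 0..W-1 and v maps to rows H-1..0
    (v = 0 is the bottom row of the image).
    """
    h, w = texture.shape[:2]
    x = np.clip(uvs[:, 0] * (w - 1), 0, w - 1)          # fractional column
    y = np.clip((1.0 - uvs[:, 1]) * (h - 1), 0, h - 1)  # fractional row
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    x1 = np.minimum(x0 + 1, w - 1)  # right neighbour, clamped at the border
    y1 = np.minimum(y0 + 1, h - 1)  # lower neighbour, clamped at the border
    wx = (x - x0)[:, None]          # horizontal blend weight, shape (T, 1)
    wy = (y - y0)[:, None]          # vertical blend weight, shape (T, 1)
    top = texture[y0, x0] * (1 - wx) + texture[y0, x1] * wx
    bottom = texture[y1, x0] * (1 - wx) + texture[y1, x1] * wx
    return top * (1 - wy) + bottom * wy                 # (T, C)
```

For example, sampling the center of a 2 x 2 single-channel texture with texels 0, 10 (top row) and 20, 30 (bottom row) returns their mean, 15. On tensors, torch.nn.functional.grid_sample with mode="bilinear" performs the same kind of lookup and avoids converting to NumPy.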