facebookresearch / pytorch3d

PyTorch3D is FAIR's library of reusable components for deep learning with 3D data
https://pytorch3d.org/
8.49k stars · 1.28k forks

PBR-inspired Shading #174

Open myaldiz opened 4 years ago

myaldiz commented 4 years ago

šŸš€ Feature

Currently, PyTorch3D supports basic shading methods (Phong, Gouraud, etc.). Although it mainly focuses on rasterization, there is still room for more realistic, PBR-inspired shading models.

Motivation

Specular highlights are a long-lived enemy of many techniques (e.g. photometric stereo), and PBR-inspired shading models can provide more realistic reflections, even within a rasterization scheme. Surface roughness can also be embedded in the shading.

Pitch

I suggest choosing a suitable PBR-inspired technique and implementing it as a shading option in PyTorch3D, and I'd like to help during the process with PRs (i.e. the Cook-Torrance model might be a good start).
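For concreteness, here is a minimal sketch of one common flavor of the Cook-Torrance specular term (GGX/Trowbridge-Reitz distribution, Schlick Fresnel, Smith-Schlick geometry), written with plain torch tensors. The function name and the `roughness`/`f0` parameters are illustrative assumptions for this sketch, not part of the PyTorch3D API, and a real PR would need to agree on the exact D/F/G variants:

```python
import math

import torch
import torch.nn.functional as F


def cook_torrance_specular(normals, view_dir, light_dir, roughness, f0):
    # All direction tensors: (..., 3), assumed unit length.
    # roughness and f0 are scalars (or tensors broadcastable to (...,)).
    h = F.normalize(view_dir + light_dir, p=2, dim=-1, eps=1e-6)
    n_dot_h = (normals * h).sum(-1).clamp(0.0, 1.0)
    n_dot_v = (normals * view_dir).sum(-1).clamp(1e-4, 1.0)
    n_dot_l = (normals * light_dir).sum(-1).clamp(1e-4, 1.0)
    v_dot_h = (view_dir * h).sum(-1).clamp(0.0, 1.0)

    # D: GGX / Trowbridge-Reitz normal distribution, alpha = roughness^2
    alpha2 = (roughness ** 2) ** 2
    denom = n_dot_h ** 2 * (alpha2 - 1.0) + 1.0
    d = alpha2 / (math.pi * denom ** 2)

    # F: Schlick approximation to the Fresnel term
    fresnel = f0 + (1.0 - f0) * (1.0 - v_dot_h) ** 5

    # G: Smith-Schlick geometry (shadowing-masking) term
    k = (roughness + 1.0) ** 2 / 8.0
    g = (n_dot_v / (n_dot_v * (1.0 - k) + k)) * (
        n_dot_l / (n_dot_l * (1.0 - k) + k)
    )

    # Cook-Torrance specular BRDF: D * F * G / (4 (n·v) (n·l))
    return d * fresnel * g / (4.0 * n_dot_v * n_dot_l)
```

The output has the same leading shape as the inputs and would be multiplied by the light intensity and `n·l` during shading.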

(btw thank you @jcjohnson for your interest)

gkioxari commented 4 years ago

Hi @myaldiz! PyTorch3D's rendering feature was designed in a modular fashion so that any shading, including custom shaders, can be added to the pipeline. Your idea of implementing a PBR-inspired shading model sounds great, and if you wish to contribute this feature to PyTorch3D, that's even more fantastic! We are happy to review your PR. The CONTRIBUTING template gives you directions on what the expectations are for contributed features.

myaldiz commented 4 years ago

Hi @gkioxari! Thank you for the fast reply. Before implementing, I would really appreciate your opinions and tips on possible directions!

jcjohnson commented 4 years ago

I think this would be a great feature to have. I'm not super familiar with PBR models, so I don't have a strong opinion on which one to implement.

If you can implement the shading model using the information we are already returning from the rasterizer, then it should be pretty straightforward and can follow the design of the other shaders:

  1. Implement a functional version of the shading model (similar to https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/renderer/mesh/shading.py#L47)
  2. Wrap it in a Shader object, following the API of our other Shaders (https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/renderer/mesh/shader.py#L29)
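The two-step pattern above can be sketched with a toy stand-in; the `toy_flat_shading` function and `ToyFlatShader` class here are placeholders I made up to show the shape of the design, not the real PyTorch3D API (which passes meshes, fragments, lights, cameras, and materials through to the functional op):

```python
import torch
from torch import nn


def toy_flat_shading(texels, light_color):
    # Step 1 stand-in: a purely functional shading rule mapping
    # interpolated texel colors to shaded colors. A real shading
    # function would also consume fragments, lights, cameras, etc.
    return texels * light_color


class ToyFlatShader(nn.Module):
    # Step 2 stand-in: a thin nn.Module wrapper whose forward()
    # delegates to the functional op, mirroring how PyTorch3D's
    # shaders call their shading function and then blend.
    def __init__(self, light_color):
        super().__init__()
        self.light_color = light_color

    def forward(self, texels):
        return toy_flat_shading(texels, self.light_color)
```

Keeping the math in a free function and the state (lights, materials, blend params) in the module is what lets users swap shaders without touching the rasterizer.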

Also tagging @patricklabatut who has more experience with PBR than me!

myaldiz commented 4 years ago

I just had a chance to look at the code more carefully. I implemented a version of the Blinn-Phong shader first to get accustomed to PyTorch3D. Through this experience, I noticed several design choices that may be worth readjusting. Before my comments, I would like to point out that I did not inspect the whole codebase thoroughly, so please correct me if I am wrong.

Here is the implementation of the Blinn-Phong model, which is much simpler than the current Phong model and fixes the issues with reflections above 90 degrees.

from pytorch3d.renderer.blending import (
    BlendParams,
    softmax_rgb_blend)
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh.texturing import (
    interpolate_texture_map, 
    interpolate_vertex_colors,
    interpolate_face_attributes)

from pytorch3d.renderer.utils import TensorProperties, convert_to_tensors_and_broadcast

from typing import Tuple

import torch
from torch import nn
import torch.nn.functional as F

def blinn_phong_shading(
    meshes, fragments, lights, cameras, materials, texels
) -> torch.Tensor:
    verts = meshes.verts_packed()  # (V, 3)
    faces = meshes.faces_packed()  # (F, 3)
    vertex_normals = meshes.verts_normals_packed()  # (V, 3)
    faces_verts = verts[faces]
    faces_normals = vertex_normals[faces]
    pixel_coords = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts
    )
    pixel_normals = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_normals
    )

    # Assume a point light source for now; reshape the per-batch light
    # location (N, 3) so it broadcasts against pixel_coords (N, H, W, K, 3)
    light_location = lights.location[:, None, None, None, :]
    light_direction = light_location - pixel_coords
    distance_squared = (light_direction.norm(p=2, dim=-1) ** 2.0).clamp_min(1e-6)
    light_direction = F.normalize(light_direction, p=2, dim=-1, eps=1e-6)

    # Tensor conversion
    matched_tensors = convert_to_tensors_and_broadcast(
        pixel_normals, 
        lights.diffuse_color,
        lights.specular_color,
        light_direction,
        cameras.get_camera_center(),
        materials.shininess,
        device=pixel_normals.device)

    # Reshape tensors
    points_dims = pixel_normals.shape[1:-1]
    expand_dims = (-1,) + (1,) * len(points_dims)
    for i, tensor in enumerate(matched_tensors): 
        if tensor.shape != pixel_normals.shape:
            matched_tensors[i] = tensor.view(
                expand_dims if i==len(matched_tensors)-1 else expand_dims + (3,)
            )   # Careful reshaping shininess

    normals, light_diffuse_color, light_specular_color, \
        light_direction, camera_position, shininess = matched_tensors

    view_direction = camera_position - pixel_coords
    view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6)
    normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)

    diffuse_intensity = \
        F.relu(torch.sum(normals * light_direction, dim=-1)) / distance_squared
    light_diffuse = light_diffuse_color * diffuse_intensity[..., None]

    halfway_vector = light_direction + view_direction
    halfway_vector = F.normalize(halfway_vector, p=2, dim=-1, eps=1e-6)
    specular_intensity = (F.relu(
        torch.sum(normals * halfway_vector, dim=-1)
    ) ** shininess) / distance_squared
    light_specular = light_specular_color * specular_intensity[..., None]

    ambient = materials.ambient_color * lights.ambient_color
    diffuse = materials.diffuse_color * light_diffuse
    specular = materials.specular_color * light_specular
    if normals.dim() == 2 and pixel_coords.dim() == 2:
        # If given packed inputs remove batch dim in output.
        ambient = ambient.squeeze()
        diffuse = diffuse.squeeze()
        specular = specular.squeeze()

    colors = (ambient + diffuse) * texels + specular
    return colors

(Tested this through render_textured_mesh tutorial notebook, file)