NVlabs / nvdiffrast

Nvdiffrast - Modular Primitives for High-Performance Differentiable Rendering

flat shading #102

Closed: yoterel closed this issue 1 year ago

yoterel commented 1 year ago

Hi, is there a sensible, performant way to flat shade a mesh (other than using the triangle_id from the rasterizer to index into a face normal tensor)? My current attempt works, but it is extremely slow compared to smooth shading:

color = (f_normals[rast_out[..., 3].long() - 1] + 1) / 2

where f_normals is a tensor of shape F×3 (one normal per face) and rast_out is the output of the rasterize operation.
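A minimal sketch of the indexing approach above, packaged as a function. It relies on nvdiffrast's documented convention that channel 3 of rast_out stores the triangle index offset by one, with 0 for pixels not covered by any triangle; in the one-liner above, such background pixels get index -1 and therefore pick up the last face's normal. The function name is hypothetical, and it still uses the same per-pixel indexing whose backward pass is being discussed, so it fixes correctness, not speed.

```python
import torch


def flat_shade_by_indexing(f_normals: torch.Tensor, rast_out: torch.Tensor) -> torch.Tensor:
    """Per-pixel face normal mapped to [0, 1] color (hypothetical helper).

    f_normals: [F, 3] float tensor, one normal per triangle.
    rast_out:  [..., 4] rasterizer output; channel 3 holds triangle_id + 1,
               with 0 marking background pixels.
    """
    tri_id = rast_out[..., 3].long()       # 0 = background, k = triangle k - 1
    covered = (tri_id > 0).unsqueeze(-1)   # [..., 1] coverage mask
    # Clamp so background pixels index a valid row instead of wrapping to -1.
    idx = (tri_id - 1).clamp(min=0)
    # Per-pixel lookup; its backward is an accumulating scatter into f_normals.
    color = (f_normals[idx] + 1) / 2
    # Zero out background pixels.
    return torch.where(covered, color, torch.zeros_like(color))
```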

s-laine commented 1 year ago

Hmm, I don't see why this would be slow, assuming that the tensors are on the GPU and there are no surprise format conversions or such. Are you sure the slowness is because of this operation, or could you perhaps be computing the normal tensor in an inefficient way?

Further questions: How slow are we talking, compared to just interpolating vertex normals? Is the forward pass slow or only the backward pass?

If you're sure it's this line that's causing the problems, maybe torch.gather works better, although I'd imagine the way you've written it maps to the same operation anyway.
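For reference, the torch.gather form of the same lookup could be written as below (the helper name is hypothetical). torch.gather needs an int64 index with the same number of dimensions as the source tensor, so the per-pixel triangle indices are flattened and expanded across the channel axis:

```python
import torch


def face_lookup_gather(f_normals: torch.Tensor, tri_idx: torch.Tensor) -> torch.Tensor:
    """Same result as f_normals[tri_idx], written with torch.gather.

    f_normals: [F, C] float tensor.
    tri_idx:   integer tensor of any shape with values in [0, F).
    """
    C = f_normals.shape[1]
    # gather wants an int64 index with the same ndim as the source: [N, C].
    index = tri_idx.long().reshape(-1, 1).expand(-1, C)
    # out[n, c] = f_normals[index[n, c], c]
    out = torch.gather(f_normals, 0, index)
    return out.reshape(*tri_idx.shape, C)
```

In the flat-shading case, tri_idx would be something like (rast_out[..., 3].long() - 1).clamp(min=0).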

yoterel commented 1 year ago

I am fairly certain this is the cause of the slowness: I have timed the forward and backward passes together, and it is approximately 10 times slower than smooth shading with the normals encoded as color (i.e., vertex normals interpolated across faces). torch.gather did not solve the problem, as the scattered indexing is indeed mapped to exactly the same operation.

Is it possibly because the gathering happens through Python rather than in a dedicated CUDA kernel? nvdiffrast seems to perform this indexing internally as well when doing interpolation, so something must give.
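To separate the forward cost from the backward cost (the question asked earlier in the thread), the two passes can be timed individually. This is a sketch, not the benchmark used in the thread; time_fwd_bwd is a hypothetical helper. CUDA kernels launch asynchronously, so the clock is read only after torch.cuda.synchronize() when a GPU is present:

```python
import time

import torch


def _sync() -> None:
    # CUDA work is asynchronous; wait for it to finish before reading the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def time_fwd_bwd(fn, *inputs, iters: int = 10):
    """Average (forward_seconds, backward_seconds) of fn(*inputs) over iters runs.

    fn must return a tensor that depends on at least one input with requires_grad=True.
    """
    fwd = bwd = 0.0
    for _ in range(iters):
        for x in inputs:
            if x.requires_grad:
                x.grad = None  # avoid accumulating gradients across runs
        _sync()
        t0 = time.perf_counter()
        out = fn(*inputs)
        _sync()
        t1 = time.perf_counter()
        out.sum().backward()
        _sync()
        t2 = time.perf_counter()
        fwd += t1 - t0
        bwd += t2 - t1
    return fwd / iters, bwd / iters
```

For example, time_fwd_bwd(lambda f, i: f[i], f_normals, tri_idx) for the indexing path versus a lambda wrapping dr.interpolate for the smooth path. The first iteration includes one-time warm-up costs, so a throwaway call beforehand gives cleaner numbers.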

s-laine commented 1 year ago

Ok, after some testing I can confirm that the backward pass of PyTorch's gather op can be extremely slow indeed. I'm impressed.

It is possible to use nvdiffrast's interpolation op for flat shading. You'll need to be a bit creative with the triangle array and exploit the fact that it doesn't need to be the same in rasterization and interpolation — this is demonstrated in, e.g., the "earth" example where position and texture coordinates have separate indices.

First, create a custom triangle array in which each triangle refers three times to a virtual vertex index equal to its own triangle index:

tri_facecolor = torch.arange(0, tri.shape[0], dtype=torch.int32, device='cuda')[:, None].expand(-1, 3).contiguous()

This gives you an index tensor that looks like

[[0, 0, 0],
 [1, 1, 1],
 [2, 2, 2],
 [3, 3, 3],
 ...
]
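Why this works: interpolation blends the three corner values of the covered triangle with barycentric weights that sum to one, and with tri_facecolor all three corners of triangle k point at row k of facecolor, so every pixel of triangle k receives facecolor[k] unchanged. The pure-PyTorch check below illustrates this at a single pixel. interpolate_at_pixel is a hypothetical reference helper, not nvdiffrast's kernel, and which corner gets u versus v does not matter here because the three corner values are identical.

```python
import torch


def interpolate_at_pixel(attr, tri, tri_id, u, v):
    """Hypothetical reference for barycentric interpolation at one pixel.

    Assumption: weights (u, v, 1 - u - v) go to corners 0, 1, 2; the
    assignment is irrelevant when all three corners share one value.
    """
    i0, i1, i2 = tri[tri_id].tolist()
    return u * attr[i0] + v * attr[i1] + (1.0 - u - v) * attr[i2]


# One color per triangle (three triangles).
facecolor = torch.tensor([[1.0, 0.0, 0.0],
                          [0.0, 1.0, 0.0],
                          [0.0, 0.0, 1.0]])
# Each triangle's three corners all point at its own row of facecolor.
tri_facecolor = torch.arange(3, dtype=torch.int32)[:, None].expand(-1, 3).contiguous()

# Any barycentrics inside triangle 1 yield facecolor[1], up to float rounding.
c = interpolate_at_pixel(facecolor, tri_facecolor, 1, 0.2, 0.5)
```

The same reasoning applies in the backward pass: a pixel's gradient flows only to row facecolor[k], with total weight u + v + (1 - u - v) = 1.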

Now use this in place of the usual triangle array when shading. So instead of this:

# vertexcolor has one color per vertex, tri refers to vertex indices
color, _ = dr.interpolate(vertexcolor, rast_out, tri)

do this:

# facecolor has one color per triangle, tri_facecolor refers to triangle indices
color, _ = dr.interpolate(facecolor, rast_out, tri_facecolor)

and you should be good to go.
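Putting the pieces together, an end-to-end flat-normal render might look like the sketch below. Assumptions: face normals are computed from world-space vertex positions pos with counter-clockwise winding, pos_clip is the [minibatch, V, 4] clip-space position tensor that dr.rasterize expects, tri is an int32 tensor on the GPU, and glctx is an existing rasterizer context such as dr.RasterizeCudaContext(). The helper names are hypothetical; only dr.rasterize and dr.interpolate are nvdiffrast calls.

```python
import torch
import torch.nn.functional as F


def face_normals(pos: torch.Tensor, tri: torch.Tensor) -> torch.Tensor:
    """Unit normal per triangle. pos: [V, 3] float, tri: [F, 3] int."""
    # This indexing is per face (F rows), not per pixel, so its backward stays small.
    v0, v1, v2 = (pos[tri[:, i].long()] for i in range(3))
    return F.normalize(torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1)


def face_index_tri(num_faces: int, device=None) -> torch.Tensor:
    """Triangle array whose three corners all refer to virtual vertex k = triangle k."""
    ids = torch.arange(num_faces, dtype=torch.int32, device=device)
    return ids[:, None].expand(-1, 3).contiguous()


def render_flat_normals(glctx, pos_clip, pos, tri, resolution):
    """Flat-shaded normals mapped to [0, 1]; resolution is [height, width]."""
    import nvdiffrast.torch as dr  # imported here so the helpers above work without nvdiffrast

    # Rasterize with the real triangle array.
    rast_out, _ = dr.rasterize(glctx, pos_clip, tri, resolution=resolution)
    # One color per triangle, with a minibatch axis of 1 so it broadcasts.
    facecolor = (face_normals(pos, tri) + 1) / 2
    # Interpolate with the per-face index array instead of tri.
    color, _ = dr.interpolate(facecolor[None, ...], rast_out, face_index_tri(tri.shape[0], tri.device))
    return color
```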