NVlabs / nvdiffrast

Nvdiffrast - Modular Primitives for High-Performance Differentiable Rendering

Unexpected small gradients at texture seams #34

Closed. wpalfi closed this issue 3 years ago.

wpalfi commented 3 years ago

The following script optimizes the texture of two triangles towards zero.

import torch
import nvdiffrast.torch as dr
import matplotlib.pyplot as plt

def tensor(*args, **kwargs):
    return torch.tensor(*args, device='cuda', **kwargs)

img_size = 64
tex_size = 64
pos = tensor([[[-0.8, -0.8, 0, 1], [0.8, -0.8, 0, 1],
             [-0.8, 0.8, 0, 1], [0.8, 0.8, 0, 1]]], dtype=torch.float32)
tri = tensor([[0, 1, 2],[1, 3, 2]], dtype=torch.int32)
# Per-triangle UVs: each triangle maps to its own region of the texture,
# so the screen-space edge the triangles share is a texture seam.
vert_uv = tensor([[[0.1, 0.1], [0.7, 0.1], [0.1, 0.7],
    [0.9, 0.9], [0.3, 0.9], [0.9, 0.3]]], dtype=torch.float32)
tri_uv = tensor([[0, 1, 2],[3, 4, 5]], dtype=torch.int32)
# Single-channel texture initialized to ones; the loss drives it towards zero.
tex = torch.full((1, tex_size, tex_size, 1), dtype=torch.float32, fill_value=1, device='cuda', requires_grad=True)

rows = []
losses = []
glctx = dr.RasterizeGLContext()
optim = torch.optim.SGD([tex],lr=1e2)
for i in range(int(1e4)):
    optim.zero_grad()
    # Rasterize the quad, interpolate UVs (plus screen-space derivatives), sample the texture.
    rast, rast_db = dr.rasterize(glctx, pos, tri, resolution=[img_size, img_size])
    uv, uv_da = dr.interpolate(vert_uv, rast, tri_uv, rast_db, diff_attrs='all')
    img = dr.texture(tex, uv, filter_mode='linear')#, uv_da)
    img = img * torch.clamp(rast[..., -1:], 0, 1) # Mask out background.
    loss = (img**2).mean()
    loss.backward()
    optim.step()
    rows.append(img[0,img_size//2,:,0].detach().cpu().numpy())
    losses.append(loss.item())

plt.subplot(2,2,1)
plt.imshow(tex[0].detach().cpu())
v = -.5 + tex_size * vert_uv[0,tri_uv[:,[0,1,2,0]].type(torch.long)].cpu().numpy()
plt.plot(v[0,:,0],v[0,:,1],'k')
plt.plot(v[1,:,0],v[1,:,1],'k')
plt.colorbar()
plt.title('tex')

plt.subplot(2,2,3)
plt.imshow(img[0].detach().cpu())
v = -.5 + img_size * (pos[0,tri[:,[0,1,2,0]].type(torch.long)]/2+.5).cpu().numpy()
plt.plot(v[0,:,0],v[0,:,1],'k')
plt.plot(v[1,:,0],v[1,:,1],'k')
plt.colorbar()
plt.title('image')

plt.subplot(2,2,2)
plt.plot(rows)
plt.title(f'image row {img_size//2}')
plt.xlabel('iteration')
plt.xscale('log')
plt.yscale('log')

plt.subplot(2,2,4)
plt.plot(losses)
plt.title('loss')
plt.xlabel('iteration')
plt.xscale('log')
plt.yscale('log')

plt.show()

I get the expected result when I run it with filter_mode='nearest' (see the attached plot): the texture converges towards zero.

But when I switch to filter_mode='linear', the optimization of the pixels at the triangle edges slows down after roughly 100 iterations, and those pixel values stay almost stuck at small constant values (see the attached plot).

The plot looks similar with mipmapping enabled.
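(For reference, enabling mipmapping here means passing the screen-space derivatives that the script already computes to the texture call, roughly like this; uv_da comes from dr.interpolate(..., diff_attrs='all') above:)

img = dr.texture(tex, uv, uv_da, filter_mode='linear-mipmap-linear')  # instead of filter_mode='linear'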

s-laine commented 3 years ago

This is an interesting observation! All I can think of is this: linear interpolation creates more complex relationships between texel values and pixel colors, and the resulting optimization problem appears to become badly conditioned.
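As a rough standalone illustration (not nvdiffrast's actual texture kernel): a single 'linear' lookup blends four neighboring texels, so the gradient of one output pixel is spread across all four of them, and neighboring pixels share texels.

import torch

def bilinear_lookup(tex, u, v):
    # tex: [H, W] single-channel texture; u, v in texel coordinates.
    u0, v0 = int(u), int(v)
    du, dv = u - u0, v - v0
    # Weighted blend of the four surrounding texels.
    return ((1 - du) * (1 - dv) * tex[v0,     u0    ] +
            du       * (1 - dv) * tex[v0,     u0 + 1] +
            (1 - du) * dv       * tex[v0 + 1, u0    ] +
            du       * dv       * tex[v0 + 1, u0 + 1])

tex = torch.ones(4, 4, requires_grad=True)
bilinear_lookup(tex, 1.25, 2.5).backward()
print(tex.grad)  # gradient is split over four texels: 0.375, 0.125, 0.375, 0.125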

As each lookup is a blend from several texel values, there are multiple ways to approach the right answer, and this is a tricky situation to optimize. SGD clearly doesn't work well in this situation, but Adam reaches high accuracy quite rapidly with learning rates such as 1e-3 or 1e-2. The loss doesn't go all the way to zero, but this isn't very surprising because the computations are done in 32-bit floating point for performance reasons which limits the precision of the gradients.