NVlabs / nvdiffrast

Nvdiffrast - Modular Primitives for High-Performance Differentiable Rendering

Non-deterministic results of optimization #13

Closed. Daiver closed this issue 3 years ago.

Daiver commented 3 years ago

Hi! First of all, thank you for such a great library; I really enjoy its performance and API design.

However, I have run into non-deterministic behavior in nvdiffrast: optimization results differ a lot between runs with the same inputs. For my experiments I use L-BFGS, but I have also observed this problem with Adam in the standard cube.py example.

I added the following lines at the beginning of the main function to make the random number generators deterministic between runs.

import random
import numpy as np
import torch

# Disable nondeterministic cuDNN autotuning and fix all RNG seeds.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
torch.cuda.manual_seed(1)

Even so, the loss values differ between runs; the final loss at the last iteration can sometimes differ by up to a factor of two.

So I have the following questions:

  1. Is such a big difference between runs OK? I understand that nvdiffrast uses atomic operations, which may lead to non-deterministic results, but I'm surprised that the difference accumulates so fast.
  2. Is it possible to reduce the accumulation of these differences?
  3. Do you have any plans to add deterministic versions of your routines? As far as I know, such non-determinism is acceptable for DL-related work, but it makes debugging other tasks, like mesh fitting, much harder.

Please let me know if I can provide any additional information. And thank you in advance!

s-laine commented 3 years ago

Hi @Daiver! I tried running samples/torch/cube.py with --resolution 16, using your code block to enforce a fixed random seed. I tried 10 different seeds with 10 runs each, and the error at iteration 1000 always matched almost exactly across runs with the same seed (the largest relative difference between min and max was 1% on one of the seeds, and it was generally around 0.1%). The same test with --resolution 256 showed somewhat more variation, at worst 3.5% but only about 1.7% on average. A longer learning rate rampdown would probably smooth these out as well. Do you see large differences even in these setups?

Answering the rest of your questions:

  1. Correct, nvdiffrast uses atomic operations, and they almost certainly lead to some discrepancies between runs. How much runs diverge because of this depends entirely on how stable the optimization process itself is. Even small differences can amplify if the system is dynamically unstable, and conversely, if there's a robust path towards an optimum, it's unlikely that a small amount of randomness will cause the optimization to jump onto a different trajectory that lands in a different optimum.
  2. Not really, unfortunately. You could split the computation into a number of smaller independent tiles to reduce the number of atomics that target the same memory location, but I wouldn't expect this to help much.
  3. Being fully deterministic would in practice require removing the atomics, and I cannot think of another way to do the scatter-like accumulation of gradients in the backward pass even remotely efficiently. So the answer is no.

There is one additional potential source of nondeterminism I can think of, and that's the construction of the topology hash for the antialiasing operation. This is nondeterministic if there are multiple possible ways to connect triangles at edges, i.e., you'd need to have more than two triangles that share an edge with the same vertex indices. In this rather pathological situation, the triangle pairs that get connected across the edge may change between executions.
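For concreteness, here is a made-up index buffer illustrating that pathological case: three triangles all referencing the edge (0, 1), so there is more than one valid way to pair them across that edge.

import numpy as np

# Hypothetical triangle list: the edge (0, 1) is shared by three triangles,
# so which pair the antialiasing topology hash connects across that edge
# is not uniquely defined.
tri = np.array([
    [0, 1, 2],   # triangle A
    [0, 1, 3],   # triangle B
    [0, 1, 4],   # triangle C, a third triangle on the same edge
], dtype=np.int32)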

Daiver commented 3 years ago

Hi @s-laine! Thank you for the answer! I'm sorry, I should have mentioned that I ran the sample with --resolution 512. I also used large resolutions (~1k) in my pet project when I noticed visible differences between runs. I ran cube.py with one seed 10 times and obtained the following loss values at iteration 1000: [0.002488, 0.002448, 0.002285, 0.001401, 0.001348, 0.002026, 0.002836, 0.001974, 0.001888, 0.002228]; min_value = 0.001348, max_value = 0.002836. I'm not sure how to compute the relative difference properly, so I used this formula: (max_value - min_value) / max_value * 100 = 52.47%.

I also ran it for --resolution 256 with a different seed 10 times (same seed for each run) and obtained the following results: [0.000791, 0.000729, 0.001533, 0.000288, 0.000322, 0.000725, 0.001075, 0.000543, 0.000612, 0.000343]; max is 0.001533, min is 0.000288, rel. diff is 81.21%.

Same for --resolution 16: [0.000234, 0.000234, 0.000247, 0.000247, 0.000248, 0.000234, 0.000248, 0.000234, 0.000234, 0.000248]; max is 0.000248, min is 0.000234, rel. diff is 5.64%. So it looks pretty stable at small resolutions.
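For reference, here is a small script that recomputes these relative differences from the loss values listed above:

def rel_diff(losses):
    # Relative difference as used above: (max - min) / max, in percent.
    return (max(losses) - min(losses)) / max(losses) * 100

res512 = [0.002488, 0.002448, 0.002285, 0.001401, 0.001348, 0.002026, 0.002836, 0.001974, 0.001888, 0.002228]
res256 = [0.000791, 0.000729, 0.001533, 0.000288, 0.000322, 0.000725, 0.001075, 0.000543, 0.000612, 0.000343]
res16  = [0.000234, 0.000234, 0.000247, 0.000247, 0.000248, 0.000234, 0.000248, 0.000234, 0.000234, 0.000248]

for name, losses in (("512", res512), ("256", res256), ("16", res16)):
    print(name, f"{rel_diff(losses):.2f}%")   # roughly 52%, 81%, and 6%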

This is my cube.py https://gist.github.com/Daiver/d2115ab4a885ed654205588e7071af55 with the seed I used for the "512 test". I ran it on Ubuntu 20.04.1 inside the nvdiffrast Docker container; my GPU is a 1080 Ti. I'm a little scared by the difference between my results and yours. Feel free to correct my computations or to ask for more information/tests.

There is one additional potential source of nondeterminism I can think of, and that's the construction of topology hash for the antialiasing operation. This is nondeterministic if there are multiple possible ways to connect triangles at edges, i.e., you'd need to have more than two triangles that share an edge with the same vertex indices.

Thank you for the explanation! My mesh is "manifold", so that should not be a problem for me.

Even small differences can amplify if the system is dynamically unstable, and conversely, if there's a robust path towards an optimum, it's unlikely that a small amount of randomness will cause the optimization to jump onto a different trajectory that lands in a different optimum.

Indeed. Unfortunately, I'm interested in precise image registration with vertex-level optimization variables, which is an ill-posed problem with tons of bad local minima. Do I understand correctly that the backward pass of any nvdiffrast stage (rasterize, interpolate, etc.) cannot be implemented efficiently without atomic operations?

s-laine commented 3 years ago

Thanks for the test case, it helps a lot. It appears that at this resolution the optimization does not properly converge in 1000 iterations. Here's an example from the end of one run with --resolution 512:

...
iter=900,err=0.001407
iter=910,err=0.001146
iter=920,err=0.001218
iter=930,err=0.001580
iter=940,err=0.001673
iter=950,err=0.002051
iter=960,err=0.001863
iter=970,err=0.001922
iter=980,err=0.002425
iter=990,err=0.003042
iter=1000,err=0.002547

The vertices are clearly still jiggling around so it's not a surprise that the last value you get has a lot of randomness in it. Note that the differences are geometrically small even though the relative differences are quite big.

If I add --max-iter 2000, the error ends up at around 0.000002, indicating a fully converged result. As an alternative to increasing the iteration count, I can tweak the learning rate schedule to ramp down more aggressively by changing lr_lambda in the scheduler constructor to max(0.01, 10**(-x*0.0010)). With the default 1000 iterations this gives final errors that are a lot smaller than with the default learning rate schedule, and they are also more consistent with each other. There is still some randomness, but the absolute differences are vanishingly small.
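Roughly, the change amounts to something like the sketch below. This is not the exact code from cube.py; the optimizer settings and the placeholder loss are stand-ins, and only the lr_lambda expression is the actual suggestion above.

import torch

# Placeholder optimization variable standing in for the vertex positions in the sample.
verts = torch.zeros(8, 3, requires_grad=True)

optimizer = torch.optim.Adam([verts], lr=1e-2)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    # More aggressive rampdown: the multiplier decays from 1.0 and hits the
    # 0.01 floor at around iteration 2000.
    lr_lambda=lambda x: max(0.01, 10**(-x*0.0010)))

for it in range(1000):
    optimizer.zero_grad()
    loss = (verts**2).mean()   # stand-in for the image-space loss
    loss.backward()
    optimizer.step()
    scheduler.step()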

The differences at the other resolutions may be explained by the fact that the 1080 Ti does not have a warp match instruction. This means the code cannot do warp-wide atomic coalescing, and each atomic update is executed individually. With coalescing, updates to the same memory location made at the same time within a warp are first summed inside the warp, and then a single atomic is executed to update memory. The accuracy of these two approaches is probably different, but it's difficult to estimate how big the difference may be.

In general, you can try to avoid bad minima in a variety of ways. Adding stochastic noise to the optimization may help with large-scale exploration, and mesh regularization terms can be used to penalize obviously unrealistic solutions. In addition, it may help to design a mesh parameterization/basis that makes it easy to change the mesh in plausible ways but difficult in implausible ways (e.g., moving a single vertex in a radically different way than its neighbors). These are still active research topics, so it's hard to give more concrete advice. But in my experience, optimizing vertex positions directly without any regularization is probably impossible to get working except in trivial cases like the cube test.
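As a toy example of such a regularization term, a smoothness penalty on per-vertex displacements could look something like the snippet below; the names, edge list, and weight are placeholders, not part of nvdiffrast.

import torch

def displacement_smoothness(offsets, edges):
    # offsets: (V, 3) per-vertex displacements being optimized
    # edges:   (E, 2) long tensor of vertex index pairs forming the mesh edges
    # Penalizes neighboring vertices moving very differently from each other.
    diff = offsets[edges[:, 0]] - offsets[edges[:, 1]]
    return (diff ** 2).sum(dim=1).mean()

# total_loss = image_loss + reg_weight * displacement_smoothness(offsets, edges)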

Your understanding of the backward pass is correct. At the very least it would require a bunch of extra computation, synchronization, and/or memory to ensure that either the atomics are executed in a consistent order, or the individual gradient contributions from each pixel are stored in memory as-is and summed together later in a consistent order.
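In toy form, that second option would look something like this (made-up names and shapes, not nvdiffrast code):

import torch

# Per-pixel gradient contributions to vertex gradients, as a backward pass might emit them.
pixel_vertex_idx = torch.tensor([2, 0, 2, 1, 2, 0])   # which vertex each contribution targets
pixel_grad = torch.randn(6, 3)                        # the contribution itself
num_vertices = 4

# Atomic-style accumulation: fast, but on the GPU the summation order is unspecified,
# so floating-point rounding can differ between runs.
grad_atomic = torch.zeros(num_vertices, 3).index_add_(0, pixel_vertex_idx, pixel_grad)

# Deterministic alternative: store all contributions, sort by target vertex, and sum
# in that fixed order. Costs extra memory, a sort, and a (here serial) reduction.
_, order = torch.sort(pixel_vertex_idx, stable=True)
grad_det = torch.zeros(num_vertices, 3)
for i in order.tolist():
    grad_det[pixel_vertex_idx[i]] += pixel_grad[i]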

Daiver commented 3 years ago

Sorry for the late response. Thank you very much for such a great and detailed explanation!

At the very least it would require a bunch of extra computation, synchronization, and/or memory to ensure that either the atomics are executed in a consistent order,

The nvdiffrast paper is well written. Maybe one day I will try to implement a deterministic version of its backward passes myself. Sounds like a fun pet project, even if it's impossible to make it fast :)

This is not related to my initial question, so I should probably close the issue. But I'd be happy to discuss registration a little bit longer.

But in my experience, optimizing vertex positions directly without any regularization is probably impossible

Indeed. By "vertex-level optimization" I mean the highly non-rigid deformation case with some regularization like ARAP. Stochastic noise is an interesting idea, but it is hard for me to tame. Unfortunately, image registration is an ill-posed optimization problem in many cases: a smaller error doesn't always mean a better registration. For example, when optimizing both lighting and geometry, "right lighting, wrong geometry" often gives a lower error than "wrong lighting, right geometry". Or take registration of skin-covered objects like hands: it's easy to get stuck in bad minima due to wrong matches between skin pores in the base and target models, and you still need a high level of local non-rigidity to match complex skin movements. Obviously, there are a lot of papers suggesting different hacks and tricks to deal with such problems. And that's why determinism is vital: it's hard to tell whether the result changed because of a trick or because the optimization simply ran differently.