NVIDIA / framework-reproducibility

Providing reproducibility in deep learning frameworks
Apache License 2.0

Message passing neural network determinism thwarted by tf.math.segment_sum and tf.gather #25

Closed — dangthatsright closed this issue 4 years ago

dangthatsright commented 4 years ago

Hi @duncanriach,

First of all, thank you for the in-depth overview of randomness on GPUs and for your work on improving determinism.

I am unfortunately unable to provide code for my issue, but I was hoping to hear what ideas you might have.

Setup: Running TensorFlow 2.2.0 with TF Estimators on AWS g4dn instances. Not using Keras models (although we do use tf.keras.regularizers.l2); mostly tf.compat.v1.layers.dense, tf.matmul, and tf.math.segment_sum. I am using an MPNN model, if that helps haha. I set os.environ['TF_DETERMINISTIC_OPS'] = '1' as you suggest, and I also have np.random.seed and tf.estimator.RunConfig(tf_random_seed) set.
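For reference, the relevant settings in my code look roughly like this (a sketch rather than a full repro; the seed value here is just illustrative):

import os
os.environ['TF_DETERMINISTIC_OPS'] = '1'  # request deterministic GPU op implementations

import numpy as np
import tensorflow as tf

SEED = 42  # illustrative value
np.random.seed(SEED)
run_config = tf.estimator.RunConfig(tf_random_seed=SEED)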

What I've tried: running the model 3 times each on CPU for 5 epochs with a 0.001 learning rate, and on GPU for 5 epochs with 0.001, 0.0001, and 0.00001 learning rates.

What I've found: all 3 runs on CPU are identical. All 3 runs on GPU differ, and the 0.001 learning rate diverges much faster than 0.0001, which in turn diverges much faster than 0.00001. You can see the results in the attached image.

I don't expect the GPU to be fully deterministic, but in order to compare changes to my models/featurizations, the amount of variation at a 0.001 learning rate on GPU is far too much. The variance at a 0.00001 learning rate is acceptable, and I was wondering if you have any ideas about what is causing this and how to mitigate it. Since the CPU is fully deterministic, I would expect this to be purely a GPU issue. Does that seem correct?

duncanriach commented 4 years ago

Thanks for the appreciation, @dangthatsright.

Without being able to run your code, it's impossible to provide a definitive answer. However, if you're seeing perfect (bit-exact) reproducibility when running on a CPU but not when running on a GPU, then it suggests that there is an op running on the GPU that is injecting non-determinism. More specifically, it suggests that (a) your trainable variables are being initialized reproducibly (i.e. before training they match between runs), that (b) your pseudorandom number generators are being reproducibly reset (seeded), and that (c) there is not some other non-GPU source of non-determinism involved.

One reason for a larger learning rate leading to more divergence could be that the difference between runs at each step is being amplified, increasing the chance of a divergence in the high-level path taken through the optimization space.

Regarding, "I don't expect GPU to be fully deterministic": it should be possible for the operation on the GPU to be perfectly (bit-exact) reproducible, and we've achieved, and witnessed, this for many models. You probably have one op that is injecting non-determinism either in the forward or backward direction and it's likely an op for which deterministic functionality has not yet been provided.

In the list of ops you mentioned, the one that stands out to me is tf.math.segment_sum. We have confirmed that tf.math.unsorted_segment_sum operates nondeterministically (in the forward direction) and this is now reported in Confirmed Current GPU-Specific Sources of Non-Determinism (With Solutions). I'm almost certain that both the sorted and unsorted versions rely on the same underlying nondeterministic CUDA kernels (in segment_reduction_ops.h and segment_reduction_ops_gpu.cu.cc).

For the next step, will you please run your system with tf.math.segment_sum pinned to the CPU and confirm that you get deterministic functionality, as follows:

with tf.device("/cpu:0"):
  sums = tf.math.segment_sum(data, segment_ids)

duncanriach commented 4 years ago

Update: I just confirmed that tf.math.segment_sum does indeed currently inject nondeterminism in the forward direction when running on a GPU. I have added info to the README with this commit.

Can you give me your first and last name and I'll add you to the credits in this repo?

dangthatsright commented 4 years ago

Thank you for the suggestion! That helped, but it didn't fix things completely. I am currently in the process of finding the other sources of non-determinism in my dataset and estimator-spec functions. I'll let you know which functions are the cause, in case they aren't documented already!

duncanriach commented 4 years ago

Cool. Thank you.

duncanriach commented 4 years ago

You might want to double-check that the sum of all of your trainable variables is exactly the same, on each run, before starting training. Also, make sure that any ops that use PRNGs and take seed parameters have those seed parameters set.
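For example, something along these lines (just a sketch, assuming you can get hold of a session, e.g. from a tf.estimator.SessionRunHook; the function name is mine):

import tensorflow as tf

def trainable_variable_checksum(sess):
    # Sum all trainable variables into one scalar and print it with full
    # precision; compare this value between runs before the first training step.
    total = tf.add_n([tf.reduce_sum(tf.cast(v, tf.float64))
                      for v in tf.compat.v1.trainable_variables()])
    print("trainable-variable checksum: %.17g" % sess.run(total))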

dangthatsright commented 4 years ago

Here's the update: I thought it might have been some dataset map issues (from your guide), but those seem to be fine. I also thought it had something to do with the estimator spec, since removing the CPU-device wrapping makes it non-deterministic.

Between the final layer and the backprop is the loss function, which uses tensorflow.python.ops.math_ops.to_float, squared_difference, and tensorflow.python.ops.losses.losses_impl.compute_weighted_loss. This is where I thought the bug would be, but unfortunately wrapping this loss_op with the CPU device didn't seem to work.

Instead I had to wrap

with tf.device("/cpu:0"):
    with tf.control_dependencies(update_ops):
        minimize_op = optimizer.minimize(
            loss_op, global_step=tf.compat.v1.train.get_global_step()
        )

which allowed it to work. This makes me think that the backprop itself is not deterministic? Is that possible? I haven't seen anything about this except for something about cross-entropy loss, but I am just using a mean-squared-error loss.

Additionally, when I add the CPU device wrapping, I notice a significant slowdown, which makes sense because tensors have to be copied back and forth between the GPU and the CPU.

Unfortunately, for my model I ran into

2020-07-24 02:04:17.789852: W ./tensorflow/core/common_runtime/gpu/gpu_host_allocator.h:44] could not allocate pinned host memory of size: 4565042688
^C^C2020-07-24 02:04:19.582345: W tensorflow/core/framework/op_kernel.cc:1753] OP_REQUIRES failed at cwise_ops_common.cc:82 : Resource exhausted: OOM when allocating tensor with shape[4060,166,166] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator cpu

which suggests I ran out of RAM. This is suboptimal, but I can switch to a larger instance. Is there a general way you suggest using these CPU-pinned ops without hurting performance (memory and training time) too much?

I appreciate any thoughts you may have!

duncanriach commented 4 years ago

In most cases, nondeterminism is injected by an op's backward path (because backprop tends to involve large reductions). tf.math.segment_sum, discussed above, is an unusual case where the nondeterminism is injected in the forward path. So if you're using another op that is injecting nondeterminism, then it's most likely to be doing that as part of the backprop calculations. In other words, what you're seeing makes total sense.

At least half of the DL compute is in the backprop, which is why running all of the backprop on the CPU slows training down a lot. Also, backprop requires the presence of the forward activations (the forward intermediate outputs) to calculate the back-propagated gradients (as well as all the intermediate gradient values), so it does tend to put an extra load on memory compared with just the forward path. With the copying (duplication) between CPU and GPU memory, the overall memory footprint would likely be even larger.

minimize calculates the gradients and then applies them to the trainable variables. It's possible to split minimize into compute_gradients followed by apply_gradients, and hopefully you will find that you only need to run compute_gradients on the CPU to get bit-exact reproducibility. That will rule out nondeterminism in the gradient application (which is unlikely in any case).
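Sketched out, using the names from your snippet above (update_ops, loss_op, optimizer), that could look something like this:

with tf.control_dependencies(update_ops):
    # Compute the gradients on the CPU (the suspected source of nondeterminism) ...
    with tf.device("/cpu:0"):
        grads_and_vars = optimizer.compute_gradients(loss_op)
    # ... but apply them on the GPU as usual.
    minimize_op = optimizer.apply_gradients(
        grads_and_vars, global_step=tf.compat.v1.train.get_global_step()
    )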

The nondeterminism debug tool, which I have not yet released publicly, can be inserted into the path of grads_and_vars, between compute_gradients (running on the GPU) and apply_gradients (also running on the GPU), to isolate the op whose GPU backprop kernel is injecting the nondeterminism. If you were able to provide me with easy-to-use repro code, then I could use the tool to isolate the source, my limited time permitting.

For now, at least, I would like you to try setting (one by one) the following parameters on minimize or compute_gradients (whichever you're using), while running everything except segment_sum on the GPU, to see if they happen to resolve the issue: gate_gradients=GATE_GRAPH, aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE, colocate_gradients_with_ops=True
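For example (a sketch; GATE_GRAPH is an attribute of the TF1 optimizer class, and the commented-out lines are the other two parameters to try one at a time):

minimize_op = optimizer.minimize(
    loss_op,
    global_step=tf.compat.v1.train.get_global_step(),
    gate_gradients=tf.compat.v1.train.Optimizer.GATE_GRAPH,
    # aggregation_method=tf.AggregationMethod.EXPERIMENTAL_TREE,
    # colocate_gradients_with_ops=True,
)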

Notes about those parameters:

dangthatsright commented 4 years ago

Thank you so much for all your help. Unfortunately, it would probably take me quite a while to condense the code, and even then it would be quite a headache because of the tf.estimator archetype. Alas, thanks to your hints, I was able to figure it out!

There was one tf.gather that I didn't originally wrap with the CPU device. Since things worked without wrapping that tf.gather on the CPU while running the backward pass on the CPU, I believe tf.gather is non-deterministic in the backward pass/gradient computation but not the forward pass (or it could be both and I got lucky?). (Also kind of strange that tf.gather uses non-deterministic operations, haha.) So setting colocate_gradients_with_ops=True did it! Unfortunately, it still slows down the code by quite a bit.

Edit to the above: I realized you linked this issue, and the very first thing they mention is that tf.gather is non-deterministic during backprop. I am a fool!

Note: it may also be that segment_sum is non-deterministic in the backward pass.

One last question, if you don't mind: if I have an op that runs on the CPU, then a couple of ops, then another op that runs on the CPU, is that better or worse than wrapping all of those ops with the CPU device? I'm guessing this is a trade-off between the speed of the ops on CPU vs. GPU and the time it takes to copy the memory?

Thanks again for the help on this non-deterministic ride. Having learned a lot more about how TensorFlow works and how to debug non-determinism, I appreciate your work even more! Also, here are some hints for others who are debugging non-determinism issues:

Edit:

  1. Do your research and read GitHub issues thoroughly before you begin.

  2. After you change something, always run it twice, because the deterministic behavior will have changed. I spent a long while thinking something was still non-deterministic when in fact it merely differed from the previously established non-deterministic model.

  3. Always look at all significant digits; usually it's the 8th or 9th digit that differs. (I was looking at metrics because it's a bit annoying with graph mode and tf.estimator.) See the snippet after this list.

  4. Confirm that the CPU is deterministic. Then pin the minimize op to the CPU only, before binary-searching your code for the non-deterministic forward pass. Then hopefully moving the minimize op back to the GPU just works; otherwise, try colocating the gradients.
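Here's what I mean by point 3 (a hypothetical snippet, not code from my model): print metrics or checksums with full float precision so that 8th/9th-digit differences are visible.

import numpy as np

np.set_printoptions(precision=17)    # show (almost) all significant digits when printing arrays
loss_value = 0.123456789012345678    # stand-in for a metric pulled from one run
print("%.17g" % loss_value)          # prints every significant digit of a Python float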

duncanriach commented 4 years ago

Great that you got it working! So happy for you.

Could you give me your first and last name so that I can add you to the credits? (Or I can add your handle, if you prefer.) It would also be great to know what company you represent (if any), so that I can add it to my list.

Regarding tf.gather: its backprop nondeterminism is actually already reported here in the main README.md. I'm now considering moving the names of those dependent ops up into the main tables. Those tables should be compact and comprehensive in terms of GPU-nondeterministic ops.

Regarding potential segment_sum nondeterminism in the backward direction: I've made a note to double check for that. I think it's very unlikely though.

Regarding mix-and-match of ops between CPU and GPU: you're correct about the trade-off. The question can only be answered for a specific case by comparing run-times each way.

> Thanks again for the help on this non-deterministic ride. Having learned a lot more about how TensorFlow works and how to debug non-determinism, I appreciate your work even more!

You're welcome. It's been fun.

Regarding your hints/tips: these are all very useful. I intend to add a guide on getting your model to determinism at the top of the main README.md, and I may include some or all of your tips in it.

Ideally, however, it should not be necessary for folks to search through the TensorFlow project issues. The information in the README.md of this repo should be a maximally comprehensive and up-to-date summary of the status of GPU-determinism in TensorFlow.

Closing.

dangthatsright commented 4 years ago

Haha great title change!

Sure, my name is Hao Shen and I'm working at Reverie Labs.

I've also found that you can use the trick from the issue above (https://github.com/tensorflow/tensorflow/issues/39751) to wrap tf.segment_sum calls and get deterministic behavior on the GPU. This gave performance comparable to the non-deterministic version!
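For anyone reading along, the general shape of that kind of workaround (a sketch, not necessarily the exact code from that issue; the helper name and the num_segments argument are just illustrative) is to reformulate the segment sum as a matmul against a one-hot matrix built from segment_ids, since the GPU matmul kernels are deterministic:

import tensorflow as tf

def deterministic_segment_sum(data, segment_ids, num_segments):
    # data: [N, D], segment_ids: [N] (sorted, as tf.math.segment_sum requires).
    # Build an [N, num_segments] one-hot matrix and reduce via matmul, which
    # runs on deterministic GPU kernels, instead of calling tf.math.segment_sum.
    one_hot = tf.one_hot(segment_ids, depth=num_segments, dtype=data.dtype)
    return tf.matmul(one_hot, data, transpose_a=True)  # shape [num_segments, D]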

duncanriach commented 4 years ago

Excellent. Thank you, @dangthatsright. I added your name to the credits.

Thanks for adding more info to TF issue 39751. I'm going to look into developing a framework-determinism patch and/or a temporary upstream solution.

dangthatsright commented 3 years ago

Hey @duncanriach, another year, another question :P With the upgrade to TF 2.4+, I'm starting to see

/miniconda/envs/reverie_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:8733: in segment_sum
    _ops.raise_from_not_ok_status(e, name)
/miniconda/envs/reverie_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:6897: in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

value = None, from_value = None

>   ???
E   tensorflow.python.framework.errors_impl.UnimplementedError: Deterministic GPU implementation of sorted segment reduction op not available. [Op:SegmentSum]

Based on my knowledge, I believe TensorFlow does a check when TF_DETERMINISTIC_OPS=1 and raises an error because a deterministic GPU implementation of the segment reduction op is not available. I'm not entirely sure what the workaround is here?

duncanriach commented 3 years ago

Hey @dangthatsright, I think that this comment on issue 39751 may answer your question. Note also that stock TF 2.7 will likely have fully deterministic segment reduction ops, assuming that the rollback of PR 51392 gets resolved.