openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Auto Clustering Leading to Invalid Argument Error #2793

Open hercule24 opened 1 year ago

hercule24 commented 1 year ago

Hi XLA Experts,

We are using TensorFlow (2.4) together with Horovod (0.23) for distributed training. We turned on auto clustering via tf.config.optimizer.set_jit(True); however, training then fails with the following error:

324227 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/keras/engine/training.py", line 1100, in fit
324228 [14]<stderr>:    tmp_logs = self.train_function(iterator)
324229 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
324230 [14]<stderr>:    result = self._call(*args, **kwds)
324231 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/eager/def_function.py", line 956, in _call
324232 [14]<stderr>:    filtered_flat_args)
324233 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/eager/function.py", line 2943, in __call__
324234 [14]<stderr>:    filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
324235 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/eager/function.py", line 1919, in _call_flat
324236 [14]<stderr>:    ctx, args, cancellation_manager=cancellation_manager))
324237 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/eager/function.py", line 560, in call
324238 [14]<stderr>:    ctx=ctx)
324239 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
324240 [14]<stderr>:    inputs, attrs, num_outputs)
324241 [14]<stderr>:tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
324242 [14]<stderr>:  (0) Invalid argument:  Trying to assign variable with wrong dtype. Expected INVALID got float
324243 [14]<stderr>:    [[{{node cond/else/_1/cond/StatefulPartitionedCall/Variable_320/cond/else/_22838/Variable_320/cond/Assign}}]]
324244 [14]<stderr>:    [[cond/else/_1/cond/StatefulPartitionedCall/assert_greater_equal/Assert/AssertGuard/branch_executed/_26438/_6409]]
324245 [14]<stderr>:  (1) Invalid argument:  Trying to assign variable with wrong dtype. Expected INVALID got float
324246 [14]<stderr>:    [[{{node cond/else/_1/cond/StatefulPartitionedCall/Variable_320/cond/else/_22838/Variable_320/cond/Assign}}]]
324247 [14]<stderr>:0 successful operations.
324248 [14]<stderr>:0 derived errors ignored. [Op:__inference_fn_with_cond_272149]
324249 [14]<stderr>:
324250 [14]<stderr>:Function call stack:
324251 [14]<stderr>:fn_with_cond -> fn_with_cond
324252 [14]<stderr>:

I am not sure if this is the right place to ask this question, but it would help greatly if you could take a quick look and suggest how I can debug this further. Thank you in advance!
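
For context, the relevant part of our setup looks roughly like the sketch below (the model, optimizer, and data are stand-ins for our real training code, and the usual Horovod callbacks are omitted):

```python
import tensorflow as tf
import horovod.tensorflow.keras as hvd

hvd.init()

# Pin each worker process to a single GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

# Turn on XLA auto clustering globally, as in our job.
tf.config.optimizer.set_jit(True)

# Placeholder model/optimizer/data; the real code builds a much larger model.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt = hvd.DistributedOptimizer(tf.keras.optimizers.Adam(1e-3))
model.compile(optimizer=opt, loss='mse')

x = tf.random.normal([32, 8])
y = tf.random.normal([32, 1])
model.fit(x, y, epochs=1)
```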

ddunl commented 1 year ago

I've shared this internally, hopefully someone who can address this will get back to you soon!

hercule24 commented 1 year ago

Hey @ddunl , is there any update?

ddunl commented 1 year ago

I still haven't heard back from anyone internally. I'd be interested to know whether this problem still occurs on a more recent version of TF; TF 2.4 is nearly two years old.

hercule24 commented 1 year ago

I am considering migrating to TF 2.11. It needs some work, and I am hoping to learn more about the root cause before I migrate.

ddunl commented 1 year ago

I think it'll be easier for me to find someone who can help if you can provide a minimal reproducible example; it's quite difficult to debug this without seeing the code.

hercule24 commented 1 year ago

Well... I cannot share the code, as I don't think company policy allows it. However, I believe the error is thrown from this line: https://github.com/tensorflow/tensorflow/blob/v2.4.0/tensorflow/compiler/jit/xla_device_ops.cc#L66

Also, decorating with @tf.function(experimental_compile=True) throws an unsupported-op error ("No registered 'SparseTensorDenseMatMul' OpKernel for XLA_GPU_JIT devices compatible with node ..."), while auto clustering throws the error above.
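
For reference, the explicit-compilation path I tried looks roughly like this sketch (the step function and inputs are placeholders; our real step contains the sparse-dense matmul that XLA rejects):

```python
import tensorflow as tf

# Explicit compilation: every op inside the function must have an XLA kernel,
# otherwise compilation fails with an "unsupported op" error like the one above.
# Auto clustering (set_jit(True)) instead leaves unsupported ops outside the
# clusters, which is why it fails in a different way.
# Note: experimental_compile was renamed to jit_compile in TF 2.5+.
@tf.function(experimental_compile=True)
def step(x, w):
    # Placeholder computation; the real step includes SparseTensorDenseMatMul.
    return tf.matmul(x, w)

x = tf.random.normal([4, 8])
w = tf.random.normal([8, 2])
print(step(x, w))
```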

hercule24 commented 1 year ago

Hey @ddunl , thank you for taking the time to look into this.

After much trial and error removing layers or swapping them out for dummy ones, I discovered that tf.keras.layers.Embedding is the problematic layer; specifically, it is the tf.gather used internally by the Embedding layer that throws the above error.

Do you know why tf.gather was not compatible? Will upgrading to TF 2.11 fix it?
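
In case it helps, a stripped-down version of the pattern that seems to be involved looks roughly like this (a sketch only, with made-up shapes; it may not reproduce the error outside our Horovod setup):

```python
import tensorflow as tf

tf.config.optimizer.set_jit(True)  # auto clustering, as in our training job

# Embedding lowers the lookup to tf.gather on the embedding table internally.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(input_dim=1000, output_dim=16),
    tf.keras.layers.GlobalAveragePooling1D(),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

ids = tf.random.uniform([64, 10], maxval=1000, dtype=tf.int32)
labels = tf.random.normal([64, 1])
model.fit(ids, labels, epochs=1)
```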

ddunl commented 1 year ago

Unfortunately I don't know; I think you may have better luck filing an issue on the TensorFlow repo instead. Sorry I can't give a more helpful answer!