google-research / google-research

Google Research
https://research.google
Apache License 2.0

[VATT] No OpKernel was registered to support Op 'AllToAll' #962

Open jonnyhe opened 2 years ago

jonnyhe commented 2 years ago

Hi @hassanhub, sorry for the interruption. I followed the instructions and ran the following command in a GPU environment, with the TPU configuration set to None:

python -m vatt.main --task=pretrain --mode=train --model_dir=PATH/TO/RUN --model_arch=tx_fac --strategy_type=mirrored

However, I got the following error, which seems to be related to TPU:

tensorflow.python.framework.errors_impl.InvalidArgumentError: No OpKernel was registered to support Op 'AllToAll' used by {{node StatefulPartitionedCall/AllToAllGather}} with these attrs: [split_count=4, concat_dimension=1, split_dimension=0, T=DT_FLOAT]
Registered devices: [CPU, GPU]
Registered kernels:

Do you have any idea what causes this error?

hassanhub commented 2 years ago

Yes, this happens because VATT is optimized for TPU kernels. If you'd like to bypass this error, you can replace this function with strategy.gather(tensor, axis=0)

OR simply replace this line with the following:

modality_1_all = modality_1
modality_2_all = modality_2

jonnyhe commented 2 years ago

Hi @hassanhub, thanks for your reply. I replaced the function with strategy.gather(tensor, axis=0), but got the following error instead:

RuntimeError: tf.distribute.Strategy.gather method requires cross-replica context, use get_replica_context().all_gather() instead.

It seems the tensor is not in a cross-replica context. Any idea how to fix this? Thanks in advance.
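The RuntimeError above suggests calling all_gather on the replica context, because the gather happens inside the per-replica step rather than in cross-replica code. A minimal sketch of that approach follows. It assumes TensorFlow 2.x with tf.distribute.MirroredStrategy; the names gather_across_replicas and gather_step are hypothetical and are not part of the VATT code base, and this has not been checked against the VATT training loop.

```python
import tensorflow as tf


def gather_across_replicas(tensor, axis=0):
    """Concatenate `tensor` from all replicas along `axis`.

    Hypothetical helper (not part of VATT). It is meant to be called
    from inside the function passed to `strategy.run`, i.e. in replica
    context, which is where the error message points to.
    """
    replica_ctx = tf.distribute.get_replica_context()
    if replica_ctx is None:
        # get_replica_context() returns None in cross-replica context.
        raise RuntimeError("Call this inside strategy.run (replica context).")
    return replica_ctx.all_gather(tensor, axis=axis)


# Uses all visible GPUs, or falls back to a single CPU replica.
strategy = tf.distribute.MirroredStrategy()


@tf.function
def gather_step(features):
    def replica_fn(x):
        # Each replica receives the embeddings of every replica.
        return gather_across_replicas(x, axis=0)

    return strategy.run(replica_fn, args=(features,))


per_replica = gather_step(tf.ones([2, 3]))
gathered = strategy.experimental_local_results(per_replica)[0]
# First dimension is the per-replica batch times the number of replicas.
print(gathered.shape, strategy.num_replicas_in_sync)
```

With this pattern the strategy object does not need to be passed into the helper, since the replica context is looked up at call time.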

cooper12121 commented 1 year ago

Excuse me, can you tell me how you replaced this function? Do I also need to pass the strategy? Thank you very much.

cooper12121 commented 1 year ago

Hello, I have replaced the line with modality_1_all = modality_1 and modality_2_all = modality_2, but I still have this problem. When I replace the whole function instead, I get the error "slice index 0 of dimension 0 out of bounds".