Open jonnyhe opened 2 years ago
Yes, this happens because VATT is optimized for TPU kernels.
If you'd like to bypass this error, you can replace this function with strategy.gather(tensor, axis=0)
OR simply replace this line with the following:
modality_1_all = modality_1
modality_2_all = modality_2
Hi @hassanhub, thanks for your reply. I have replaced the function with strategy.gather(tensor, axis=0). However, I got the following error instead: RuntimeError: tf.distribute.Strategy.gather method requires cross-replica context, use get_replica_context().all_gather() instead. It seems that the tensor is not in cross-replica context. Any idea how to fix the problem? Thanks in advance.
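For anyone hitting the same RuntimeError: the message itself points at get_replica_context().all_gather(), which is the replica-context counterpart of strategy.gather. Below is a minimal sketch of that call under MirroredStrategy. It is not the VATT code; the names gather_across_replicas, replica_fn and step are hypothetical, and it assumes the gather is invoked inside a function passed to strategy.run (which is what the error suggests). Whether dropping this in is enough for the VATT training loop has not been verified here.

```python
import tensorflow as tf


def gather_across_replicas(tensor, axis=0):
    """Concatenate `tensor` from all replicas along `axis`.

    Hypothetical helper: it must be called in replica context, i.e. from
    inside a function that is executed through strategy.run.
    """
    replica_ctx = tf.distribute.get_replica_context()
    return replica_ctx.all_gather(tensor, axis=axis)


# Uses all visible GPUs, or falls back to a single CPU replica.
strategy = tf.distribute.MirroredStrategy()


@tf.function
def step():
    def replica_fn():
        # Stand-in for a per-replica embedding batch of shape [2, 3].
        local = tf.ones([2, 3])
        return gather_across_replicas(local, axis=0)

    return strategy.run(replica_fn)


# Every replica receives the same gathered tensor; inspect the first one.
gathered = strategy.experimental_local_results(step())[0]
print(gathered.shape)
```

With N replicas in sync, the gathered tensor should have N times the local batch size along axis 0, so on a single device it has the same shape as the local tensor.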
Excuse me, can you tell me how you replaced this function? Do I also need to pass the strategy? Thank you very much.
Hello, I have replaced the lines with modality_1_all = modality_1 and modality_2_all = modality_2, but I still have this problem. When I replace the whole function instead, I still get the error "slice index 0 of dimension 0 out of bounds".
Hi @hassanhub, sorry for the interruption. I followed the instructions and ran the command 'python -m vatt.main --task=pretrain --mode=train --model_dir=PATH/TO/RUN --model_arch=tx_fac --strategy_type=mirrored' in a GPU environment, with the TPU configuration set to None. However, I got the following error, which seems to be related to the TPU: tensorflow.python.framework.errors_impl.InvalidArgumentError: No OpKernel was registered to support Op 'AllToAll' used by {{node StatefulPartitionedCall/AllToAllGather}} with these attrs: [split_count=4, concat_dimension=1, split_dimension=0, T=DT_FLOAT] Registered devices: [CPU, GPU] Registered kernels: