The logsumexp tensor must have rank 3. It was considered to be for internal use only, but it is apparently exposed to the unit tests.
The stream should be selected only after the current device has been set according to the input tensor. It is unclear whether this is critical to SWDEV-459623, but there is no harm in fixing it as well.
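To illustrate why the ordering matters, here is a toy model in plain Python. It is not the PyTorch or ROCm API: `ToyRuntime`, `stream_before_device`, and `stream_after_device` are hypothetical names, and the only assumption is that the "current stream" is tracked per device, so querying it before switching devices returns a stream that belongs to the wrong device.

```python
class ToyRuntime:
    """Hypothetical stand-in for a GPU runtime that keeps one current stream per device."""

    def __init__(self, num_devices):
        self.current_device = 0
        # One named stream per device; real runtimes hold opaque handles instead.
        self._streams = {d: f"stream-of-device-{d}" for d in range(num_devices)}

    def set_device(self, device):
        self.current_device = device

    def current_stream(self):
        # The answer depends on whichever device is current at call time.
        return self._streams[self.current_device]


def stream_before_device(rt, tensor_device):
    # Problematic order: the stream is queried while the old device is still current.
    stream = rt.current_stream()
    rt.set_device(tensor_device)
    return stream


def stream_after_device(rt, tensor_device):
    # Intended order: switch to the input tensor's device first, then query the stream.
    rt.set_device(tensor_device)
    return rt.current_stream()
```

With an input tensor on device 1 and device 0 initially current, the first function returns the stream of device 0 while the second returns the stream of device 1.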
Fixes #SWDEV-459623
Tested on rocm-framework-51. Log:
(py_3.10) xinyazha@12fd1640b6b7:~/rocm-pytorch$ PYTORCH_TEST_WITH_ROCM=1 python test/distributed/_tensor/test_attention.py -k test_ring_attention_compile_attention -v
test_ring_attention_compile_attention_fn0 (__main__.RingAttentionTest) ... ok
test_ring_attention_compile_attention_fn1 (__main__.RingAttentionTest) ... [rank1]:[W530 08:05:10.073449407 ProcessGroupNCCL.cpp:1113] WARNING: process group has NOT been destroyed before it is being destructed. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL data transfers have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4
[rank0]:[W530 08:05:10.123574261 ProcessGroupNCCL.cpp:1113] WARNING: process group has NOT been destroyed before it is being destructed. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL data transfers have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4
skipped 'Test skipped at subprocess level, look at subprocess log for skip reason'
----------------------------------------------------------------------
Ran 2 tests in 11.265s
OK (skipped=1)