It's necessary to adapt to this change in the function signature; otherwise tests may fail, e.g.:
root@ced723ee1a16:/opt/pytorch/apex/tests/L0# python run_transformer/test_microbatches.py -v NcclMicrobatchCalculatorTest.test_constant_microbatch_calculator
test_constant_microbatch_calculator (__main__.NcclMicrobatchCalculatorTest) ... INFO:numba.cuda.cudadrv.driver:init
Process process 0:
Traceback (most recent call last):
File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
TypeError: DistributedTestBase._run() got an unexpected keyword argument 'fake_pg'
Process process 1:
Traceback (most recent call last):
File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
TypeError: DistributedTestBase._run() got an unexpected keyword argument 'fake_pg'
FAIL
======================================================================
FAIL: test_constant_microbatch_calculator (__main__.NcclMicrobatchCalculatorTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/opt/pytorch/pytorch/torch/testing/_internal/common_distributed.py", line 541, in wrapper
self._join_processes(fn)
File "/opt/pytorch/pytorch/torch/testing/_internal/common_distributed.py", line 767, in _join_processes
self._check_return_codes(elapsed_time)
File "/opt/pytorch/pytorch/torch/testing/_internal/common_distributed.py", line 842, in _check_return_codes
self.assertEqual(
File "/opt/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3826, in assertEqual
raise error_metas.pop()[0].to_error(
AssertionError: Scalars are not equal!
Expected 0 but got 1.
Absolute difference: 1
Relative difference: inf
Expected zero exit code but got 1 for pid: 723
----------------------------------------------------------------------
Ran 1 test in 3.608s
FAILED (failures=1)
PyTorch PR https://github.com/pytorch/pytorch/pull/131510 added the keyword argument `fake_pg` to the function signature of `MultiProcessTestCase._run`. See https://github.com/pytorch/pytorch/blob/19ff9059ebe1f946e65b82fb386ad0d7b6eb69d7/torch/testing/_internal/common_distributed.py#L583-L595
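A minimal sketch of why the override breaks and one way to adapt it. It does not import PyTorch or apex: `OldStyleTest`, `AdaptedTest`, and `launch` are hypothetical stand-ins, and the positional parameters are illustrative assumptions rather than a copy of the real signature. The only thing taken from the log above is that the launcher passes `fake_pg` as a keyword argument.

```python
class OldStyleTest:
    """Hypothetical override of _run written before fake_pg existed."""

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe):
        return (rank, test_name)


class AdaptedTest:
    """Hypothetical override that tolerates extra keyword arguments."""

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe, **kwargs):
        # Accept fake_pg (and any later additions) instead of rejecting them.
        fake_pg = kwargs.get("fake_pg", False)
        return (rank, test_name, fake_pg)


def launch(target):
    """Stand-in for the launcher: forwards fake_pg as a keyword argument."""
    return target(0, "test_example", "/tmp/example", None, fake_pg=False)


# The old-style override rejects the new keyword argument.
try:
    launch(OldStyleTest._run)
    old_error = None
except TypeError as exc:
    old_error = str(exc)

# The adapted override accepts it.
adapted_result = launch(AdaptedTest._run)
```

Accepting `**kwargs` rather than naming `fake_pg` explicitly is one possible design choice: it keeps the override working if further keyword arguments are added upstream, at the cost of silently ignoring ones it does not know about.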