NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Fix DistributedTestBase for transformer distributed tests #1829

Closed xwang233 closed 3 months ago

xwang233 commented 3 months ago

PyTorch PR https://github.com/pytorch/pytorch/pull/131510 added a keyword argument `fake_pg` to the signature of `MultiProcessTestCase._run`. See https://github.com/pytorch/pytorch/blob/19ff9059ebe1f946e65b82fb386ad0d7b6eb69d7/torch/testing/_internal/common_distributed.py#L583-L595

It's necessary to adapt to this change in the function signature; otherwise, tests may fail, e.g.

```
root@ced723ee1a16:/opt/pytorch/apex/tests/L0# python run_transformer/test_microbatches.py -v NcclMicrobatchCalculatorTest.test_constant_microbatch_calculator
test_constant_microbatch_calculator (__main__.NcclMicrobatchCalculatorTest) ... INFO:numba.cuda.cudadrv.driver:init
Process process 0:
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
TypeError: DistributedTestBase._run() got an unexpected keyword argument 'fake_pg'
Process process 1:
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
TypeError: DistributedTestBase._run() got an unexpected keyword argument 'fake_pg'
FAIL

======================================================================
FAIL: test_constant_microbatch_calculator (__main__.NcclMicrobatchCalculatorTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/opt/pytorch/pytorch/torch/testing/_internal/common_distributed.py", line 541, in wrapper
    self._join_processes(fn)
  File "/opt/pytorch/pytorch/torch/testing/_internal/common_distributed.py", line 767, in _join_processes
    self._check_return_codes(elapsed_time)
  File "/opt/pytorch/pytorch/torch/testing/_internal/common_distributed.py", line 842, in _check_return_codes
    self.assertEqual(
  File "/opt/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3826, in assertEqual
    raise error_metas.pop()[0].to_error(
AssertionError: Scalars are not equal!

Expected 0 but got 1.
Absolute difference: 1
Relative difference: inf
Expected zero exit code but got 1 for pid: 723

----------------------------------------------------------------------
Ran 1 test in 3.608s

FAILED (failures=1)
```
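For illustration, a minimal sketch of the kind of signature adaptation involved: when a base class like `MultiProcessTestCase` gains a new keyword argument (here `fake_pg`), an overriding `_run` that declares a fixed parameter list raises the `TypeError` above, while one that accepts and forwards `**kwargs` stays compatible. The classes below are simplified stand-ins, not the actual apex or PyTorch implementations:

```python
# Simplified stand-in for torch.testing._internal.common_distributed.MultiProcessTestCase.
# The base _run now takes an extra keyword argument (fake_pg in the real PR).
class MultiProcessTestCase:
    @classmethod
    def _run(cls, rank, test_name, file_name, parent_pipe, *, fake_pg=False):
        return (rank, test_name, fake_pg)


# Hypothetical subclass in the spirit of apex's DistributedTestBase:
# accepting **kwargs and forwarding them means any keyword arguments the
# base class adds later (e.g. fake_pg) pass through without a TypeError.
class DistributedTestBase(MultiProcessTestCase):
    @classmethod
    def _run(cls, rank, test_name, file_name, parent_pipe, **kwargs):
        # forward unknown keyword arguments to the base implementation
        return super()._run(rank, test_name, file_name, parent_pipe, **kwargs)


# A fixed-signature override, by contrast, would fail exactly as in the log:
#   DistributedTestBase._run() got an unexpected keyword argument 'fake_pg'
```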
xwang233 commented 3 months ago

cc @crcrpar