huggingface / nanotron

Minimalistic large language model 3D-parallelism training
Apache License 2.0

PyTorch profiler is unable to serialize numpy datatypes sometimes inserted as process group ranks #177

Open hatanp opened 4 months ago

hatanp commented 4 months ago

Some process groups are initialized with ranks given as NumPy arrays, and others with plain lists. The NumPy integer types cause a failure when the profiler tries to serialize dist_info to JSON:

[rank0]:     prof.step()
[rank0]:   File "/soft/applications/conda/2024-04-29/mconda3/lib/python3.11/site-packages/torch/profiler/profiler.py", line 727, in step
[rank0]:     self._transit_action(prev_action, self.current_action)
[rank0]:   File "/soft/applications/conda/2024-04-29/mconda3/lib/python3.11/site-packages/torch/profiler/profiler.py", line 744, in _transit_action
[rank0]:     action()
[rank0]:   File "/soft/applications/conda/2024-04-29/mconda3/lib/python3.11/site-packages/torch/profiler/profiler.py", line 177, in start_trace
[rank0]:     self.add_metadata_json("distributedInfo", json.dumps(dist_info))
[rank0]:                                               ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/soft/applications/conda/2024-04-29/mconda3/lib/python3.11/json/__init__.py", line 231, in dumps
[rank0]:     return _default_encoder.encode(obj)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/soft/applications/conda/2024-04-29/mconda3/lib/python3.11/json/encoder.py", line 200, in encode
[rank0]:     chunks = self.iterencode(o, _one_shot=True)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/soft/applications/conda/2024-04-29/mconda3/lib/python3.11/json/encoder.py", line 258, in iterencode
[rank0]:     return _iterencode(o, 0)
[rank0]:            ^^^^^^^^^^^^^^^^^
[rank0]:   File "/soft/applications/conda/2024-04-29/mconda3/lib/python3.11/json/encoder.py", line 180, in default
[rank0]:     raise TypeError(f'Object of type {o.__class__.__name__} '
[rank0]: TypeError: Object of type int64 is not JSON serializable
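The root cause can be reproduced without the profiler or torch.distributed. This is a minimal sketch that assumes only NumPy and the standard json module; the dictionary below merely stands in for dist_info and is not its real layout. json refuses NumPy integer scalars, and wrapping an array in list() does not help because the elements stay np.int64.

```python
import json

import numpy as np

# Ranks stored as a NumPy integer array, as some nanotron process groups do
ranks = np.arange(4, dtype=np.int64)


def try_dumps(obj):
    """Return the JSON string, or the error text if json cannot encode obj."""
    try:
        return json.dumps(obj)
    except TypeError as err:
        return f"TypeError: {err}"


# list() keeps the np.int64 elements, so encoding still fails
print(try_dumps({"ranks": list(ranks)}))
# ndarray.tolist() converts each element to a built-in int, so encoding succeeds
print(try_dumps({"ranks": ranks.tolist()}))
```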

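You can inspect the data type of each process group's ranks with:

import torch.distributed as dist

# Run this after the process groups have been created; _get_all_pg_configs is a private PyTorch API
configs = dist.distributed_c10d._get_all_pg_configs()
for config in configs:
    # Print the Python type of the first rank stored for this group
    print(f"{config['ranks'][0].__class__.__name__}")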

I was able to work around the issue by adding the following to nanotron.distributed.new_group:

    if isinstance(ranks, np.ndarray):
        ranks = ranks.tolist()

However, I am not sure this is the right long-term fix. Looking around, I see that ParallelContext.create_new_group has an np.ndarray type hint but sometimes receives a list as well. Ideally the ranks should always arrive in the same format and respect the type hints. Passing plain Python integers to torch.distributed.new_group is probably the better choice, since it avoids breaking things like the profiler, and the PyTorch documentation also specifies ranks as list[int]. An alternative to the modification above would be to add an assert there that ensures the input is a list, and then change the code elsewhere so that only lists are passed to nanotron.distributed.new_group.
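The two options discussed above (convert inside new_group, or assert and make callers pass lists) can be sketched side by side. This is an illustration under stated assumptions, not nanotron code: normalize_ranks and check_rank_list are hypothetical names, and neither calls torch.distributed.new_group, so the sketch runs without an initialized process group. normalize_ranks also covers lists built from NumPy scalars (for example list(some_array)), which an isinstance(ranks, np.ndarray) check alone would let through.

```python
import numbers

import numpy as np


def normalize_ranks(ranks):
    """Hypothetical helper (option 1): coerce an iterable of ranks to a plain list[int].

    Accepts NumPy integer arrays, lists/tuples of NumPy or Python integers, and ranges.
    """
    if isinstance(ranks, np.ndarray):
        # For integer arrays, tolist() yields built-in Python ints
        ranks = ranks.tolist()
    result = []
    for r in ranks:
        # NumPy integer scalars are registered as numbers.Integral; bool is rejected explicitly
        if isinstance(r, bool) or not isinstance(r, numbers.Integral):
            raise TypeError(f"rank must be an integer, got {type(r).__name__}")
        result.append(int(r))
    return result


def check_rank_list(ranks):
    """Hypothetical strict check (option 2): only accept a list of built-in ints.

    Callers are then responsible for passing lists. Note that assert statements
    are skipped when Python runs with -O.
    """
    assert isinstance(ranks, list), f"ranks must be a list, got {type(ranks).__name__}"
    bad_types = sorted({type(r).__name__ for r in ranks if type(r) is not int})
    assert not bad_types, f"ranks must be built-in ints, got {bad_types}"
    return ranks


# Option 1: convert at the boundary, then hand the list to torch.distributed.new_group
print(normalize_ranks(np.arange(4, dtype=np.int64)))
print(normalize_ranks([np.int64(1), np.int32(3)]))

# Option 2: a NumPy-derived list is rejected instead of silently converted
try:
    check_rank_list(list(np.arange(2, dtype=np.int64)))
except AssertionError as err:
    print(f"AssertionError: {err}")
```

With option 1, new_group would only need a line such as ranks = normalize_ranks(ranks) before calling into torch. Option 2 keeps new_group strict and pushes the conversion to its callers, presumably including ParallelContext.create_new_group.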

Python: 3.11.8
PyTorch: 2.3.0
nanotron: up-to-date main branch
Config: created by examples/bench_llama_7b.py, with the profiler enabled:

profiler:
  profiler_export_path: ./checkpoints