Open carmocca opened 8 months ago
I don't think this is an nvFuser issue. The nvFuser standalone repro does not fail. I wonder if it was just the place where the CUDA error first got caught. On an H100, I am seeing a different error with NCCL:
```
W0502 04:04:05.809000 140711848431488 torch/distributed/run.py:778]
W0502 04:04:05.809000 140711848431488 torch/distributed/run.py:778] *****************************************
W0502 04:04:05.809000 140711848431488 torch/distributed/run.py:778] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0502 04:04:05.809000 140711848431488 torch/distributed/run.py:778] *****************************************
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/repro.py", line 13, in <module>
[rank0]:     config = Config(block_size=256, padded_vocab_size=32000, n_layer=6, n_head=6, head_size=48, n_embd=288, rotary_percentage=1.0, parallel_residual=False, bias=False, _norm_class='RMSNorm', _mlp_class='LLaMAMLP', intermediate_size=768)
[rank0]: TypeError: Config.__init__() got an unexpected keyword argument '_norm_class'
[rank0]:[W502 04:04:17.154035416 ProcessGroupNCCL.cpp:1113] WARNING: process group has NOT been destroyed before it is being destructed. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL data transfers have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4
W0502 04:04:17.734000 140711848431488 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 1344 closing signal SIGTERM
E0502 04:04:18.298000 140711848431488 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 1343) of binary: /usr/bin/python
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 33, in <module>
    sys.exit(load_entry_point('torch', 'console_scripts', 'torchrun')())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 900, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 891, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
repro.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-05-02_04:04:17
  host      : viking-prod-229.ipp2u1.colossus.nvidia.com
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 1343)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
```
You are correct, Kevin. This is not an nvFuser issue. The code was also using some removed arguments; I updated the description.
There are 2 problems at play here:

1. With `jit(fsdp(model))`, we are incorrectly sharding the shared param twice, which leads to an index-out-of-bounds error (as the size of the index is smaller than expected). The patch below fixes this problem.

```diff
diff --git a/thunder/distributed/__init__.py b/thunder/distributed/__init__.py
index c9aa00a..5ae1554 100644
--- a/thunder/distributed/__init__.py
+++ b/thunder/distributed/__init__.py
@@ -13,6 +13,7 @@ from functools import partial
 
 import torch
 import torch.distributed as tdist
+from torch.utils.weak import WeakTensorKeyDictionary
 
 import thunder.core.utils as utils
 from thunder.core.proxies import DDPType
@@ -559,6 +560,9 @@ def _shard_params(
         local_rank = int(os.environ["LOCAL_RANK"])
         device = torch.device("cuda", local_rank)
 
+    # In case there is weight/param sharing, we don't want to shard the same param
+    # multiple times. We use `sharded_params` to keep track of already sharded param to avoid resharding it.
+    sharded_params = WeakTensorKeyDictionary()
     # We will definitely change the sharding logic in the future
     for module_name, submodule in module.named_modules():
         # Materialize meta-parameters on-device if necessary.
@@ -581,7 +585,10 @@
         # Note [FSDP Sharding]
         # All internal code will assume that the parameters are sharded on the first dimension
         for param_name, param in submodule.named_parameters(recurse=False, prefix=module_name):
+            if param in sharded_params:
+                continue
             _shard_param(param, global_rank, world_size, param_name, allow_padding_for_fsdp=allow_padding_for_fsdp)
+            sharded_params[param] = True
 
 
 def _shard_param(
```
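For illustration only, here is a minimal, self-contained sketch (not Thunder code) of why tracking already-sharded tensors matters when weights are tied. `TinyModel` and `toy_shard` are hypothetical stand-ins for the real model and `_shard_param`:

```python
import torch
from torch.utils.weak import WeakTensorKeyDictionary

# Toy model with weight tying: lm_head shares its weight with the embedding.
class TinyModel(torch.nn.Module):
    def __init__(self, vocab=8, dim=4):
        super().__init__()
        self.wte = torch.nn.Embedding(vocab, dim)
        self.lm_head = torch.nn.Linear(dim, vocab, bias=False)
        self.lm_head.weight = self.wte.weight  # shared parameter

def toy_shard(param: torch.nn.Parameter, rank: int, world_size: int) -> None:
    # Hypothetical stand-in for _shard_param: keep only this rank's slice of dim 0.
    param.data = param.data.chunk(world_size, dim=0)[rank].clone()

model = TinyModel()
sharded_params = WeakTensorKeyDictionary()
for module_name, submodule in model.named_modules():
    for param_name, param in submodule.named_parameters(recurse=False, prefix=module_name):
        if param in sharded_params:  # without this check the tied weight would be
            continue                 # sharded twice and end up smaller than expected
        toy_shard(param, rank=0, world_size=2)
        sharded_params[param] = True

print(model.wte.weight.shape)                    # torch.Size([4, 4]) -- sharded exactly once
assert model.wte.weight is model.lm_head.weight  # sharing is preserved
```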
NOTE: `fsdp(jit(model))` works fine, as it refers to the parameter from the original model, creates a shallow copy, and shards the shallow copy.
2. In the execution trace, the shared parameter shows up as two distinct tensors (`t_lm_head_weight` and `t_transformer_wte_weight`), each of which is all-gathered separately:

```python
# idx: "cuda:0 i64[128, 256]"
# tos1: "cuda:0 f32[256, 24]"
# t_lm_head_weight: "cuda:0 f32[16000, 144]"
p2 = torch_all_gather_prim_impl(t_lm_head_weight, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p2: "FUTURE cuda:0 f32[32000, 144]"
p20 = torch_all_gather_prim_impl(t_transformer_wte_weight, _torch_distributed_distributed_c10d_ProcessGroup_0, True)  # p20: "FUTURE cuda:0 f32[32000, 144]"
```
where `torch_all_gather_prim_impl` is the following snippet, which creates a new output tensor for each call:
https://github.com/Lightning-AI/lightning-thunder/blob/7d6e540a6ec0bebcb6bedc5ceadbc82d1b982b65/thunder/executors/torchex.py#L1753-L1770
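To make the aliasing issue concrete, here is a hedged sketch with no real process group: `fake_all_gather` is a hypothetical stand-in for the executor snippet above, kept only to show that a gather allocating a fresh output on every call loses the fact that both calls received the same shared tensor:

```python
import torch

def fake_all_gather(t: torch.Tensor, world_size: int = 2) -> torch.Tensor:
    # Hypothetical stand-in for torch_all_gather_prim_impl: like the real
    # implementation, it allocates a brand-new output tensor on every call.
    out = torch.empty((world_size * t.shape[0], *t.shape[1:]), dtype=t.dtype)
    out.copy_(t.repeat(world_size, *([1] * (t.dim() - 1))))  # pretend every rank holds `t`
    return out

shared = torch.randn(4, 3)    # stands in for the tied lm_head/wte shard
g1 = fake_all_gather(shared)  # the trace sees t_lm_head_weight -> p2
g2 = fake_all_gather(shared)  # the trace sees t_transformer_wte_weight -> p20

# Both calls received the very same tensor, yet the outputs are independent
# allocations: the logical aliasing between them is lost downstream.
print(g1 is g2)                        # False
print(g1.data_ptr() == g2.data_ptr())  # False
print(torch.equal(g1, g2))             # True (same values, duplicated memory and communication)
```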
To tackle 2, I think we need to add some notion of aliasing. Related to inplace support https://github.com/Lightning-AI/lightning-thunder/issues/145, which also has to consider aliasing.
Could you please submit your fix for 1? It's a perfect solution to this problem.
For 2, I think the Thunder JIT could recognize these situations and pass just one tensor to the computational trace.
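A rough sketch of the kind of deduplication this could look like, purely illustrative (the helper `dedup_shared_params` and the `Tied` module are made up, not Thunder APIs): collapse parameters that are the same object to a single canonical name, so only one tensor (and one all-gather) enters the computational trace.

```python
import torch

def dedup_shared_params(module: torch.nn.Module) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
    """Return (unique params keyed by canonical name, alias name -> canonical name)."""
    canonical: dict[str, torch.Tensor] = {}
    aliases: dict[str, str] = {}
    seen: dict[int, str] = {}  # id(param) -> canonical name
    for name, p in module.named_parameters(remove_duplicate=False):
        if id(p) in seen:
            aliases[name] = seen[id(p)]  # e.g. "lm_head.weight" -> "wte.weight"
        else:
            seen[id(p)] = name
            canonical[name] = p
    return canonical, aliases

class Tied(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = torch.nn.Embedding(8, 4)
        self.lm_head = torch.nn.Linear(4, 8, bias=False)
        self.lm_head.weight = self.wte.weight  # weight tying

unique, alias_map = dedup_shared_params(Tied())
print(list(unique))  # ['wte.weight'] -- only one tensor would be passed to the trace
print(alias_map)     # {'lm_head.weight': 'wte.weight'}
```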
🐛 Bug
To Reproduce
Code:
Run with:
Error:
Removing one of:
makes the problem not appear
cc @carmocca @awaelchli @crcrpar