pytorch / torchchat

Run PyTorch LLMs locally on servers, desktop and mobile
BSD 3-Clause "New" or "Revised" License

[Distributed Inference] moving stage.submod to non-fp32 (bf16, fp16) results in dtensor assert "self.mask_buffer.data is not None" #1086

Closed: lessw2020 closed this issue 2 months ago.

lessw2020 commented 2 months ago

🐛 Describe the bug

Using our prototype parallel blocks for built-in distributed inference, we can run TP + PP in fp32 successfully. However, moving the model to bfloat16 or fp16 results in an embedding assert:

[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/pipelining/stage.py", line 530, in forward_one_chunk
[rank0]:[rank0]:     output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/pipelining/stage.py", line 464, in forward_maybe_with_nosync
[rank0]:[rank0]:     out_val = self.submod(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:   File "/data/users/less/local/torchchat/build/model_dist.py", line 105, in forward
[rank0]:[rank0]:     x: DTensor = self.tok_embeddings(x)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1801, in _call_impl
[rank0]:[rank0]:     hook_result = hook(self, args, result)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 895, in <lambda>
[rank0]:[rank0]:     lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/tensor/parallel/style.py", line 251, in _prepare_output_fn
[rank0]:[rank0]:     outputs = outputs.redistribute(placements=output_layouts, async_op=True)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/_tensor/api.py", line 541, in redistribute
[rank0]:[rank0]:     return Redistribute.apply(self, device_mesh, placements, async_op)
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/autograd/function.py", line 575, in apply
[rank0]:[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/_tensor/_redistribute.py", line 295, in forward
[rank0]:[rank0]:     output = redistribute_local_tensor(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/_tensor/_redistribute.py", line 196, in redistribute_local_tensor
[rank0]:[rank0]:     new_local_tensor = partial_spec._reduce_value(
[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/_tensor/ops/_embedding_ops.py", line 119, in _reduce_value
[rank0]:[rank0]:     assert self.mask_buffer.data is not None
[rank0]:[rank0]: AssertionError

This issue is to track the debugging and resolution.

Versions

N/A

kwen2501 commented 2 months ago

Looking at the error stack:

[rank0]:[rank0]:   File "/home/less/local/miniconda3/envs/newserver/lib/python3.10/site-packages/torch/distributed/_tensor/ops/_embedding_ops.py", line 119, in _reduce_value
[rank0]:[rank0]:     assert self.mask_buffer.data is not None
[rank0]:[rank0]: AssertionError

This seems to come from the TP (tensor parallel) package. Regardless of the error itself, this check seems a bit intrusive.

kwen2501 commented 2 months ago
assert self.mask_buffer.data is not None

Is this assert complaining that the weight (of embedding) is not loaded?
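On the question above, here is a simplified, pure-Python sketch of vocab-parallel (row-sharded) embedding lookup, the general technique tensor parallelism applies to embeddings. This is not DTensor's code, and the names `local_lookup` and `reduce_partials` are hypothetical. Each rank looks up only the token ids in its own vocab range and records a mask of the ids it does not own. At reduction time the masked rows must be zeroed before the cross-rank sum, so the mask saved during lookup has to still be available. In DTensor's embedding ops, `mask_buffer` appears to play this role (holding that mask rather than the embedding weight), so the assert would fire when the reduction runs without a saved mask.

```python
from typing import List, Optional, Sequence, Tuple

Row = List[float]


def local_lookup(token_ids: Sequence[int], local_table: Sequence[Row],
                 vocab_start: int) -> Tuple[List[Row], List[bool]]:
    """One rank's lookup against its shard of the embedding table (hypothetical helper).

    Returns the partial embedding rows and a mask that is True wherever the
    token id belongs to another rank's shard.
    """
    vocab_end = vocab_start + len(local_table)
    mask = [not (vocab_start <= t < vocab_end) for t in token_ids]
    # Out-of-shard ids are redirected to local row 0 so indexing stays valid;
    # those rows hold meaningless values until the mask is applied.
    local_ids = [0 if masked else t - vocab_start
                 for t, masked in zip(token_ids, mask)]
    rows = [list(local_table[i]) for i in local_ids]
    return rows, mask


def reduce_partials(partials: Sequence[Tuple[List[Row], Optional[List[bool]]]]) -> List[Row]:
    """Zero each rank's masked rows, then sum across ranks (stands in for the all-reduce)."""
    total: Optional[List[Row]] = None
    for rows, mask in partials:
        # Mirrors the failing check: without the saved mask we cannot tell
        # which rows are placeholders, so the reduction cannot proceed.
        assert mask is not None
        cleaned = [[0.0] * len(r) if m else r for r, m in zip(rows, mask)]
        if total is None:
            total = cleaned
        else:
            total = [[a + b for a, b in zip(x, y)] for x, y in zip(total, cleaned)]
    assert total is not None
    return total


if __name__ == "__main__":
    full_table = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]]
    shards = [(full_table[:2], 0), (full_table[2:], 2)]  # 2 ranks, 2 vocab rows each
    tokens = [3, 0, 2]
    partials = [local_lookup(tokens, table, start) for table, start in shards]
    print(reduce_partials(partials))  # [[3.0, 3.1], [0.0, 0.1], [2.0, 2.1]]
```

The reduced result equals a lookup into the unsharded table, but only because every rank's mask survived from the lookup step to the reduction step.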

lessw2020 commented 2 months ago

I've found that if the sample activations we pass in when creating the PipelineStage have the same dtype that we will later run the model in, then I am able to run the model with real weights in the proper dtype. In this case that dtype is fp16, so we have to pass in fp16 sample activations, and then things work as expected. Passing in fp32 sample activations and then running in a half-precision dtype is what triggers this error.
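This workaround can be sketched as deriving the dtype of the sample activations from the stage module itself instead of hardcoding fp32. A minimal sketch, assuming a decoder-style stage whose inputs are hidden states of shape (batch, seqlen, dim); the helpers `stage_dtype` and `example_hidden_states` are hypothetical, and the returned tensor is meant to be what gets passed as the example input when building the PipelineStage.

```python
from typing import Union

import torch
import torch.nn as nn


def stage_dtype(stage_module: nn.Module) -> torch.dtype:
    """Return the dtype of the stage's first floating-point parameter."""
    for p in stage_module.parameters():
        if p.is_floating_point():
            return p.dtype
    # No floating-point parameters: fall back to the global default (usually fp32).
    return torch.get_default_dtype()


def example_hidden_states(stage_module: nn.Module, batch: int, seqlen: int, dim: int,
                          device: Union[torch.device, str] = "cpu") -> torch.Tensor:
    # Sample activation whose dtype matches the dtype the stage will actually
    # run in, rather than a hardcoded fp32 tensor.
    return torch.zeros(batch, seqlen, dim, dtype=stage_dtype(stage_module), device=device)


if __name__ == "__main__":
    stage = nn.Linear(16, 16).to(torch.float16)  # stand-in for a half-precision stage
    x = example_hidden_states(stage, batch=1, seqlen=4, dim=16)
    print(x.dtype, tuple(x.shape))  # torch.float16 (1, 4, 16)
```

Reading the dtype from the parameters means the sample activations follow whatever cast was applied to the stage, so the two cannot drift apart.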

lessw2020 commented 2 months ago

I've further updated things to store the checkpoint dtype in the internal lookup table and adjust the run dtype accordingly. Things now work as expected, and the user does not have to provide or care about the dtype: it just works. This is ultimately a limitation of the built-in parallelism, but we now handle it automatically, so I'm closing this out.
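The automated fix can be sketched as a lookup from model name to checkpoint dtype, with the result used for both the stage's weights and its sample activations. This is a minimal sketch, not torchchat's actual table or code: `CHECKPOINT_DTYPES`, `resolve_run_dtype`, `prepare_stage`, and the model names are hypothetical placeholders, and the fp32 fallback for unknown models is an assumption.

```python
from typing import Dict, Tuple

import torch

# Hypothetical lookup table: model name -> dtype the checkpoint was saved in.
# The entries are placeholders, not real torchchat table contents.
CHECKPOINT_DTYPES: Dict[str, torch.dtype] = {
    "example-model-bf16": torch.bfloat16,
    "example-model-fp16": torch.float16,
}


def resolve_run_dtype(model_name: str, default: torch.dtype = torch.float32) -> torch.dtype:
    """Look up the checkpoint dtype; unknown models fall back to `default` (assumed fp32)."""
    return CHECKPOINT_DTYPES.get(model_name, default)


def prepare_stage(module: torch.nn.Module, model_name: str,
                  batch: int, seqlen: int, dim: int) -> Tuple[torch.nn.Module, torch.Tensor]:
    dtype = resolve_run_dtype(model_name)
    module = module.to(dtype)
    # Weights and sample activations share one dtype, so the user never has to specify it.
    example = torch.zeros(batch, seqlen, dim, dtype=dtype)
    return module, example
```

Keeping the dtype decision in one place is what removes the mismatch: the cast and the sample activations cannot disagree because both read the same resolved value.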

[Screenshot, 2024-08-29 11:55:02 AM]