pytorch / PiPPy

Pipeline Parallelism for PyTorch
BSD 3-Clause "New" or "Revised" License

FSDP+PP tracer issue with cast-to-bf16 #1104

Open wconstab opened 2 months ago

wconstab commented 2 months ago

https://github.com/pytorch/torchtitan/pull/161/files#diff-80b04fce2b861d9470c6160853441793678ca13904dae2a9b8b7145f29cd017aR254


In principle, the issue is that PP traced the non-FSDP model. In that case, the model code's .to(f32) operation was a no-op and dropped out of the trace, or something like that.
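As a minimal eager-mode illustration of the no-op (this shows the mechanism only and does not reproduce the tracer): Tensor.to() returns the tensor itself when it already has the requested dtype, so in the non-FSDP model there is no cast for a tracer to record. The bf16 tensor below is just a stand-in for what FSDP mixed precision would hand the layer.

```python
import torch

# Non-FSDP case: the activation is already fp32.
x = torch.randn(2, 3, dtype=torch.float32)

# Tensor.to() returns `self` when no conversion is needed, so this "cast" does no work.
same = x.to(torch.float32)

# Simulated mixed-precision case: the same line now performs a real bf16 -> fp32 cast.
x_bf16 = x.to(torch.bfloat16)
converted = x_bf16.to(torch.float32)

print(same is x, converted is x_bf16)
```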

The only proposal I recall was to change the tracer/export to handle this better and not drop the .to operation. Need to check whether this has already been resolved.

cc @zhxchen17 @kwen2501

wconstab commented 2 months ago

https://github.com/pytorch/pytorch/pull/123732 was intended to help this case but isn't quite enough:

1) #123732 does not appear to help for calls to .float(); it only seems to work for explicit calls to .to(). I verified that if I replace the .float() calls with .to(torch.float32) calls, the errors I previously saw went away.
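A sketch of that workaround on a hypothetical module (Head is illustrative, not torchtitan code): the upcast is spelled as an explicit .to(torch.float32) instead of .float(). The two spellings give identical results in eager mode; per the comment above, only the .to() form appeared to be handled by #123732.

```python
import torch
import torch.nn as nn


class Head(nn.Module):
    """Hypothetical module used only to illustrate the rewrite."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Before (reportedly not preserved in the trace):
        #     return self.proj(h).float()
        # After: the same upcast written as an explicit .to() call.
        return self.proj(h).to(torch.float32)


model = Head(8).to(torch.bfloat16)
h = torch.randn(4, 8, dtype=torch.bfloat16)
out = model(h)
ref = model.proj(h).float()  # the original spelling, for comparison
print(out.dtype, torch.equal(out, ref))
```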

cc @zhxchen17

2) After the changes for (1), I get a new error which I'm still trying to understand.


I don't see any explicit casting operators here. From the looks of it, FSDP is expected to cast the layer inputs to bf16 but isn't, OR perhaps the inputs are in bf16 but for some reason the parameters are not?

These are the inputs to the exact op (bmm) that threw the exception:

target = <OpOverload(op='aten.bmm', overload='default')>
args = [
    (torch.Size([64, 2048, 2048]), torch.float32),
    (torch.Size([64, 2048, 16]), torch.bfloat16),
]
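A small eager reproduction of that failure mode, with shapes scaled down from the logged ones (an illustration, not the original program): bmm does not promote mixed dtypes, so an fp32 left operand paired with a bf16 right operand raises. The describe helper is hypothetical and only mimics the (shape, dtype) dump above.

```python
import torch


def describe(*tensors):
    # (shape, dtype) summary in the same spirit as the failing-op dump above.
    return [(t.shape, t.dtype) for t in tensors]


# Scaled-down stand-ins for [64, 2048, 2048] fp32 and [64, 2048, 16] bf16.
scores = torch.randn(2, 8, 8, dtype=torch.float32)
values = torch.randn(2, 8, 4, dtype=torch.bfloat16)

try:
    torch.bmm(scores, values)
    error = None
except RuntimeError as exc:
    error = exc  # mixed fp32/bf16 operands are rejected

print(describe(scores, values))
print(error)

# Casting one side makes the op valid again; which side *should* be cast is the real question here.
fixed = torch.bmm(scores.to(values.dtype), values)
```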

This paste shows one level higher in the graph: the whole attention module. https://www.internalfb.com/phabricator/paste/view/P1229735878

Note that the traced code burns a float32 dtype kwarg into the view calls for xq, xk and xv, while the actual model code does not pass float32 to the view.

I think this is the bug?

kwen2501 commented 1 month ago

An example program shows that torch.export does not burn the dtype into the ExportedProgram at trace time: https://github.com/kwen2501/export-playground/blob/main/dtype.py (see the kwargs for zeros_like).

$ python dtype.py

opcode         name                target                   args                     kwargs
-------------  ------------------  -----------------------  -----------------------  ---------------------
placeholder    p_embedding_weight  p_embedding_weight       ()                       {}
placeholder    x                   x                        ()                       {}
call_function  embedding           aten.embedding.default   (p_embedding_weight, x)  {}
call_function  zeros_like          aten.zeros_like.default  (embedding,)             {'pin_memory': False}
output         output              output                   ((zeros_like,),)         {}
kwen2501 commented 1 month ago

The dtype burned into zeros_like in the issue's program is likely what causes the dtype mismatch at bmm.

We can use graph_module.print_readable() to see the original stack traces and identify which part of the code set the dtype to FP32.

kwen2501 commented 1 month ago

I also confirmed that a line like z = torch.zeros_like(y, dtype=y.dtype) burns the dtype into the kwargs:

# in forward, code: z = torch.zeros_like(y, dtype=y.dtype)
zeros_like: "f32[2, 4, 3]" = torch.ops.aten.zeros_like.default(embedding, dtype = torch.float32, pin_memory = False);  
kwen2501 commented 1 month ago

The doc of torch.zeros_like says:

torch.zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format)

dtype (torch.dtype, optional) – the desired data type of returned Tensor. Default: if None, defaults to the dtype of input.

Thus, it is safe to just write:

z = torch.zeros_like(y)

instead of

z = torch.zeros_like(y, dtype=y.dtype)
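In eager mode the two spellings are equivalent, which is why the rewrite is safe; the difference only matters once a tracer records y.dtype as a constant. A minimal check, using a bf16 tensor so the default-dtype behavior is visible:

```python
import torch

y = torch.randn(2, 4, 3, dtype=torch.bfloat16)

z_explicit = torch.zeros_like(y, dtype=y.dtype)  # dtype becomes a constant when traced
z_inherit = torch.zeros_like(y)                  # dtype follows y at run time

print(z_inherit.dtype, torch.equal(z_explicit, z_inherit))
```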

AI: we need to find the code written in the second style above and fix it.
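One way to act on this item, sketched as a stdlib-only AST scan. find_redundant_dtype_calls is a hypothetical helper: it matches on names only, does not resolve imports or aliases, and flags *_like calls whose dtype argument is just the first argument's .dtype. An explicit constant such as dtype=torch.float32 is deliberately left alone, since that may be intentional.

```python
import ast

LIKE_FACTORIES = ("zeros_like", "ones_like", "empty_like", "full_like", "rand_like", "randn_like")


def find_redundant_dtype_calls(source, funcs=LIKE_FACTORIES):
    """Return line numbers of calls like f(y, ..., dtype=y.dtype) where f is a *_like factory."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Call) or not node.args:
            continue
        func = node.func
        name = func.attr if isinstance(func, ast.Attribute) else getattr(func, "id", None)
        if name not in funcs:
            continue
        for kw in node.keywords:
            if (
                kw.arg == "dtype"
                and isinstance(kw.value, ast.Attribute)
                and kw.value.attr == "dtype"
                # same expression as the first positional argument, e.g. y and y.dtype
                and ast.dump(kw.value.value) == ast.dump(node.args[0])
            ):
                hits.append(node.lineno)
    return hits


sample = (
    "import torch\n"
    "def f(y):\n"
    "    a = torch.zeros_like(y, dtype=y.dtype)\n"
    "    b = torch.zeros_like(y)\n"
    "    c = torch.zeros_like(y, dtype=torch.float32)\n"
    "    return a, b, c\n"
)
print(find_redundant_dtype_calls(sample))
```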

kwen2501 commented 1 month ago

Exporting the llama model and printing the stack traces shows that the zeros_like comes from scaled_dot_product_attention:

        # File: /data/users/kw2501/torchtitan/torchtitan/models/llama/model.py:203 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        mul_4: "f32[8, 16, 2048, 16]" = torch.ops.aten.mul.Scalar(transpose, 0.5);  transpose = None
        ones: "b8[2048, 2048]" = torch.ops.aten.ones.default([2048, 2048], dtype = torch.bool, layout = torch.strided, device = device(type='meta'))
        tril: "b8[2048, 2048]" = torch.ops.aten.tril.default(ones);  ones = None
        zeros_like: "f32[2048, 2048]" = torch.ops.aten.zeros_like.default(tril, dtype = torch.float32)
        logical_not: "b8[2048, 2048]" = torch.ops.aten.logical_not.default(tril);  tril = None
        masked_fill: "f32[2048, 2048]" = torch.ops.aten.masked_fill.Scalar(zeros_like, logical_not, -inf);  zeros_like = logical_not = None
        transpose_3: "f32[8, 16, 16, 2048]" = torch.ops.aten.transpose.int(transpose_1, -2, -1);  transpose_1 = None
        mul_5: "f32[8, 16, 16, 2048]" = torch.ops.aten.mul.Scalar(transpose_3, 0.5);  transpose_3 = None
        expand: "f32[8, 16, 2048, 16]" = torch.ops.aten.expand.default(mul_4, [8, 16, 2048, 16]);  mul_4 = None
        clone: "f32[8, 16, 2048, 16]" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format);  expand = None
        _unsafe_view: "f32[128, 2048, 16]" = torch.ops.aten._unsafe_view.default(clone, [128, 2048, 16]);  clone = None
        expand_1: "f32[8, 16, 16, 2048]" = torch.ops.aten.expand.default(mul_5, [8, 16, 16, 2048]);  mul_5 = None
        clone_1: "f32[8, 16, 16, 2048]" = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format);  expand_1 = None
        _unsafe_view_1: "f32[128, 16, 2048]" = torch.ops.aten._unsafe_view.default(clone_1, [128, 16, 2048]);  clone_1 = None
        bmm: "f32[128, 2048, 2048]" = torch.ops.aten.bmm.default(_unsafe_view, _unsafe_view_1);  _unsafe_view = _unsafe_view_1 = None
        view_14: "f32[8, 16, 2048, 2048]" = torch.ops.aten.view.default(bmm, [8, 16, 2048, 2048]);  bmm = None
        add_1: "f32[8, 16, 2048, 2048]" = torch.ops.aten.add.Tensor(view_14, masked_fill);  view_14 = masked_fill = None
        _softmax: "f32[8, 16, 2048, 2048]" = torch.ops.aten._softmax.default(add_1, -1, False);  add_1 = None
        expand_2: "f32[8, 16, 2048, 2048]" = torch.ops.aten.expand.default(_softmax, [8, 16, 2048, 2048]);  _softmax = None
        view_15: "f32[128, 2048, 2048]" = torch.ops.aten.view.default(expand_2, [128, 2048, 2048]);  expand_2 = None
        expand_3: "f32[8, 16, 2048, 16]" = torch.ops.aten.expand.default(transpose_2, [8, 16, 2048, 16]);  transpose_2 = None
        clone_2: "f32[8, 16, 2048, 16]" = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format);  expand_3 = None
        _unsafe_view_2: "f32[128, 2048, 16]" = torch.ops.aten._unsafe_view.default(clone_2, [128, 2048, 16]);  clone_2 = None
        bmm_1: "f32[128, 2048, 16]" = torch.ops.aten.bmm.default(view_15, _unsafe_view_2);  view_15 = _unsafe_view_2 = None
        view_16: "f32[8, 16, 2048, 16]" = torch.ops.aten.view.default(bmm_1, [8, 16, 2048, 16]);  bmm_1 = None
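To make the decomposition easier to read, here is a hedged Python sketch that mirrors the trace (it is not the ATen source). The two mul.Scalar(..., 0.5) nodes are consistent with scaling q and k^T each by sqrt(1/sqrt(head_dim)), since head_dim is 16. The optional mask_dtype argument is a hypothetical knob that mimics the float32 constant burned into zeros_like: with bf16 inputs, the add promotes the scores to fp32, so the final matmul sees an fp32/bf16 pair, which would be consistent with the bmm error reported earlier in this thread.

```python
import math

import torch
import torch.nn.functional as F


def causal_attention_reference(q, k, v, mask_dtype=None):
    """Sketch of the math path visible in the trace; mask_dtype=None means "follow q at run time"."""
    L, S = q.size(-2), k.size(-2)
    # Split the 1/sqrt(head_dim) softmax scale evenly between q and k^T (0.5 each for head_dim=16).
    scale = math.sqrt(1.0 / math.sqrt(q.size(-1)))
    causal = torch.ones(L, S, dtype=torch.bool, device=q.device).tril()
    # Passing mask_dtype=torch.float32 imitates zeros_like(tril, dtype=torch.float32) in the graph.
    dtype = q.dtype if mask_dtype is None else mask_dtype
    mask = torch.zeros_like(causal, dtype=dtype).masked_fill(causal.logical_not(), float("-inf"))
    attn = (q * scale) @ (k.transpose(-2, -1) * scale)
    attn = torch.softmax(attn + mask, dim=-1)  # fp32 mask + bf16 scores promotes to fp32
    return attn @ v


q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
out = causal_attention_reference(q, k, v)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out, ref, atol=1e-4, rtol=1e-4))

# Simulate the FSDP run: bf16 activations, but a mask dtype frozen to fp32 at trace time.
qb, kb, vb = (t.to(torch.bfloat16) for t in (q, k, v))
try:
    causal_attention_reference(qb, kb, vb, mask_dtype=torch.float32)
    burned_error = None
except RuntimeError as exc:
    burned_error = exc
print(burned_error)
```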
kwen2501 commented 1 month ago

More specifically:

https://github.com/pytorch/pytorch/blob/a5c93a6899c657832944cd2eeb5069449e28dbea/aten/src/ATen/native/transformers/attention.cpp#L523

kwen2501 commented 1 month ago

CC: @zhxchen17 @tugsbayasgalan, let me know whether you are preparing an improvement to unburn the dtype as well (in addition to the device)? We would be thrilled to try that out. CC: @wconstab

kwen2501 commented 1 month ago

Meanwhile, @tugsbayasgalan mentioned that pre-dispatch mode is now the default mode of torch.export. This new mode can also work around the issue, since it avoids tracing into SDPA.