wconstab opened 6 months ago
https://github.com/pytorch/pytorch/pull/123732 was intended to help this case but isn't quite enough.
1) #123732 does not appear to help for calls to `.float()`; it only seems to work for explicit calls to `.to()`. I verified that if I replace the `.float()` calls with `.to(torch.float32)` calls, the errors I previously saw went away.
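For reference, a minimal sketch of the kind of change this refers to (a hypothetical norm-style layer for illustration, not the actual model code):

```python
import torch
import torch.nn as nn

class UpcastNorm(nn.Module):
    # Hypothetical mixed-precision layer that upcasts activations internally.
    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # h = x.float()            # before: tracing did not handle this for us
        h = x.to(torch.float32)    # after: explicit .to() call traces correctly
        return (self.weight * h).to(x.dtype)
```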
cc @zhxchen17
2) After the changes for (1) I get a new error which I'm still trying to understand.
I don't see any explicit casting operators here, and from the looks of it FSDP is expected to cast the layer inputs to bf16 but isn't, or perhaps the inputs are in bf16 but for some reason the parameters are not?
These are the inputs to the exact op (`aten.bmm`) that threw the exception:
target = <OpOverload(op='aten.bmm', overload='default')>
args = [
    (torch.Size([64, 2048, 2048]), torch.float32),
    (torch.Size([64, 2048, 16]), torch.bfloat16),
]
This paste shows one level higher in the graph, i.e. the whole attention module: https://www.internalfb.com/phabricator/paste/view/P1229735878
Note the traced code burns a `float32` dtype kwarg into the `view` calls for `xq`, `xk`, `xv`, while the actual model code does not pass `float32` as part of the `view`. I think this is the bug?
An example program shows that `torch.export` would not burn dtype into the `ExportedProgram` at trace time: https://github.com/kwen2501/export-playground/blob/main/dtype.py
See the kwargs for `zeros_like`.
$ python dtype.py
opcode name target args kwargs
------------- ------------------ ----------------------- ----------------------- ---------------------
placeholder p_embedding_weight p_embedding_weight () {}
placeholder x x () {}
call_function embedding aten.embedding.default (p_embedding_weight, x) {}
call_function zeros_like aten.zeros_like.default (embedding,) {'pin_memory': False}
output output output ((zeros_like,),) {}
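For context, `dtype.py` is roughly of this shape (reconstructed from the printed graph above; the actual file may differ):

```python
import torch
import torch.nn as nn
from torch.export import export

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 3)

    def forward(self, x):
        y = self.embedding(x)
        # No explicit dtype argument here, so export leaves dtype out of the node's kwargs.
        return torch.zeros_like(y)

ep = export(M(), (torch.randint(0, 10, (2, 4)),))
ep.graph.print_tabular()  # prints the opcode/name/target/args/kwargs table shown above
```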
The `zeros_like`'s dtype in the issue's program is likely the one that causes the dtype mismatch at `bmm`. We can use `graph_module.print_readable()` to see the original stack trace and identify which part of the code set it to FP32.
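For example, continuing the `dtype.py`-style sketch above (`ep` being the `ExportedProgram`):

```python
# Prints the generated code with per-node "# File: ... code: ..." comments
# that point back to the original model source line producing each node.
ep.graph_module.print_readable()
```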
Also confirmed that a line like `z = torch.zeros_like(y, dtype=y.dtype)` would burn the dtype into the kwargs:
# in forward, code: z = torch.zeros_like(y, dtype=y.dtype)
zeros_like: "f32[2, 4, 3]" = torch.ops.aten.zeros_like.default(embedding, dtype = torch.float32, pin_memory = False);
The doc of `torch.zeros_like` says:
torch.zeros_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format)
dtype (torch.dtype, optional) – the desired data type of returned Tensor. Default: if None, defaults to the dtype of input.
Thus, it is safe to just write `z = torch.zeros_like(y)` instead of `z = torch.zeros_like(y, dtype=y.dtype)`.
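A trivial sanity check of that claim:

```python
import torch

y = torch.randn(2, 3, dtype=torch.bfloat16)
# zeros_like inherits y's dtype by default, so passing dtype=y.dtype is redundant,
# and omitting it keeps a hard-coded dtype out of the exported graph.
assert torch.zeros_like(y).dtype == y.dtype == torch.bfloat16
```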
Action item: we'd need to find the code that is written in the 2nd style above and fix it.
Exporting the llama model and printing the stack shows me that the `zeros_like` comes from `scaled_dot_product_attention`:
# File: /data/users/kw2501/torchtitan/torchtitan/models/llama/model.py:203 in forward, code: output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
mul_4: "f32[8, 16, 2048, 16]" = torch.ops.aten.mul.Scalar(transpose, 0.5); transpose = None
ones: "b8[2048, 2048]" = torch.ops.aten.ones.default([2048, 2048], dtype = torch.bool, layout = torch.strided, device = device(type='meta'))
tril: "b8[2048, 2048]" = torch.ops.aten.tril.default(ones); ones = None
zeros_like: "f32[2048, 2048]" = torch.ops.aten.zeros_like.default(tril, dtype = torch.float32)
logical_not: "b8[2048, 2048]" = torch.ops.aten.logical_not.default(tril); tril = None
masked_fill: "f32[2048, 2048]" = torch.ops.aten.masked_fill.Scalar(zeros_like, logical_not, -inf); zeros_like = logical_not = None
transpose_3: "f32[8, 16, 16, 2048]" = torch.ops.aten.transpose.int(transpose_1, -2, -1); transpose_1 = None
mul_5: "f32[8, 16, 16, 2048]" = torch.ops.aten.mul.Scalar(transpose_3, 0.5); transpose_3 = None
expand: "f32[8, 16, 2048, 16]" = torch.ops.aten.expand.default(mul_4, [8, 16, 2048, 16]); mul_4 = None
clone: "f32[8, 16, 2048, 16]" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None
_unsafe_view: "f32[128, 2048, 16]" = torch.ops.aten._unsafe_view.default(clone, [128, 2048, 16]); clone = None
expand_1: "f32[8, 16, 16, 2048]" = torch.ops.aten.expand.default(mul_5, [8, 16, 16, 2048]); mul_5 = None
clone_1: "f32[8, 16, 16, 2048]" = torch.ops.aten.clone.default(expand_1, memory_format = torch.contiguous_format); expand_1 = None
_unsafe_view_1: "f32[128, 16, 2048]" = torch.ops.aten._unsafe_view.default(clone_1, [128, 16, 2048]); clone_1 = None
bmm: "f32[128, 2048, 2048]" = torch.ops.aten.bmm.default(_unsafe_view, _unsafe_view_1); _unsafe_view = _unsafe_view_1 = None
view_14: "f32[8, 16, 2048, 2048]" = torch.ops.aten.view.default(bmm, [8, 16, 2048, 2048]); bmm = None
add_1: "f32[8, 16, 2048, 2048]" = torch.ops.aten.add.Tensor(view_14, masked_fill); view_14 = masked_fill = None
_softmax: "f32[8, 16, 2048, 2048]" = torch.ops.aten._softmax.default(add_1, -1, False); add_1 = None
expand_2: "f32[8, 16, 2048, 2048]" = torch.ops.aten.expand.default(_softmax, [8, 16, 2048, 2048]); _softmax = None
view_15: "f32[128, 2048, 2048]" = torch.ops.aten.view.default(expand_2, [128, 2048, 2048]); expand_2 = None
expand_3: "f32[8, 16, 2048, 16]" = torch.ops.aten.expand.default(transpose_2, [8, 16, 2048, 16]); transpose_2 = None
clone_2: "f32[8, 16, 2048, 16]" = torch.ops.aten.clone.default(expand_3, memory_format = torch.contiguous_format); expand_3 = None
_unsafe_view_2: "f32[128, 2048, 16]" = torch.ops.aten._unsafe_view.default(clone_2, [128, 2048, 16]); clone_2 = None
bmm_1: "f32[128, 2048, 16]" = torch.ops.aten.bmm.default(view_15, _unsafe_view_2); view_15 = _unsafe_view_2 = None
view_16: "f32[8, 16, 2048, 16]" = torch.ops.aten.view.default(bmm_1, [8, 16, 2048, 16]); bmm_1 = None
More specifically, this happens in pytorch/aten/src/ATen/native/transformers/attention.cpp, in `convert_boolean_attn_mask`, which converts the boolean causal mask into an additive mask.
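The C++ snippet isn't reproduced here, but judging from the traced graph above, the effective behavior of that conversion is roughly the following (a Python sketch reconstructed from the graph, not the actual implementation):

```python
import torch

def convert_boolean_attn_mask_sketch(attn_mask: torch.Tensor, query_dtype: torch.dtype) -> torch.Tensor:
    # Turn a boolean "keep" mask into an additive mask in the query's dtype:
    # 0 where attention is allowed, -inf where it is masked out.
    # The zeros_like(..., dtype=query_dtype) call is where torch.float32 gets
    # burned into the exported graph when the trace runs in full precision.
    additive = torch.zeros_like(attn_mask, dtype=query_dtype)
    return additive.masked_fill(attn_mask.logical_not(), float("-inf"))
```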
CC: @zhxchen17 @tugsbayasgalan, let me know if you are preparing an improvement to un-burn the dtype as well (in addition to device)? We would be thrilled to try that out. CC: @wconstab
Meanwhile, @tugsbayasgalan mentioned that pre-dispatch mode is now the default mode of torch.export. That can also work around this issue, since the new mode avoids tracing into SDPA; see the sketch after the link below.
https://github.com/pytorch/torchtitan/pull/161/files#diff-80b04fce2b861d9470c6160853441793678ca13904dae2a9b8b7145f29cd017aR254
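A hedged sketch of how to check that (the export API surface for pre-dispatch has moved around between versions, so treat this as an assumption rather than a recipe):

```python
import torch
import torch.nn.functional as F
from torch.export import export

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

x = torch.randn(2, 4, 128, 16)
ep = export(Attn(), (x, x, x))
# In a pre-dispatch graph, SDPA should appear as a single node instead of the
# ones/tril/zeros_like/bmm/softmax decomposition shown in the trace above.
ep.graph_module.print_readable()
```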
In principle, the issue is that the PP tracing was done on the non-FSDP model, and in that case the model code ran a `.to(f32)` operation which was a no-op and dropped out of the trace, or something like that.
The only proposal I recall was to change the tracer/export to handle this better and not drop the `.to` operation (see the repro sketch below). Need to check whether this has already been resolved.
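A minimal repro sketch of that suspicion (whether the no-op cast survives in the graph may depend on the PyTorch version, so this is an experiment rather than a statement of current behavior):

```python
import torch
from torch.export import export

class NoOpCast(torch.nn.Module):
    def forward(self, x):
        # On an fp32 example input this cast is a no-op; the question is whether
        # the tracer keeps an explicit cast node or drops it, in which case the
        # exported graph would no longer upcast bf16 inputs at runtime.
        return x.to(torch.float32) * 2

ep = export(NoOpCast(), (torch.randn(4),))
print(ep.graph)  # look for an aten.to / aten._to_copy node
```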
cc @zhxchen17 @kwen2501