pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

torch.matmul output buffer dtype is not respected when output dtype is different from input dtype #7160

Open HahTK opened 5 months ago

HahTK commented 5 months ago

❓ Questions and Help

When calling torch.matmul(input, other, out=c) where the dtype of c differs from the input dtype, the dtype of c is respected by PyTorch but not by XLA. Is this the expected behavior or a bug?

Example

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

seq, d = 512, 128
q = torch.rand((d, seq), dtype=torch.bfloat16, device="cpu").to(device)
k = torch.rand((d, seq), dtype=torch.bfloat16, device="cpu").to(device)
c = torch.zeros((seq, seq), dtype=torch.float32, device="cpu").to(device)  # float32 output buffer
torch.matmul(q.t(), k, out=c)  # bfloat16 inputs, float32 out=

print(c)
print(c.dtype)

This generates the following hlo.pbtxt:

HloModule SyncTensorsGraph.6, entry_computation_layout={(bf16[128,512]{1,0},bf16[128,512]{1,0})->(bf16[512,512]{1,0})}

ENTRY SyncTensorsGraph.6 {
  p1 = bf16[128,512]{1,0} parameter(1), frontend_attributes={neff_input_name="input1"}
  transpose = bf16[512,128]{1,0} transpose(p1), dimensions={1,0}
  p0 = bf16[128,512]{1,0} parameter(0), frontend_attributes={neff_input_name="input0"}
  dot = bf16[512,512]{1,0} dot(transpose, p0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  ROOT tuple = (bf16[512,512]{1,0}) tuple(dot), frontend_attributes={neff_output_names="output0"}
}

Therefore, the computation in HLO happens in BF16. However, the output shows c.dtype == torch.float32, so the dtype of the output tensor as seen from PyTorch is still correct.

Side note: JAX respects the f32 output dtype and lowers the matmul to dot = f32[512,512]{1,0} dot(transpose, p0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
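For comparison, here is a minimal JAX sketch of the same computation. JAX arrays have no out= buffer, so this assumes the f32 result is requested through the preferred_element_type argument of jax.lax.dot; the random inputs and the name mixed_matmul are illustrative only.

```python
import jax
import jax.numpy as jnp

seq, d = 512, 128
kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.uniform(kq, (d, seq)).astype(jnp.bfloat16)
k = jax.random.uniform(kk, (d, seq)).astype(jnp.bfloat16)

def mixed_matmul(q, k):
    # bf16 operands; the requested f32 result type is attached to the dot itself
    return jax.lax.dot(q.T, k, preferred_element_type=jnp.float32)

c = jax.jit(mixed_matmul)(q, k)
print(c.dtype)  # float32

# Lowered StableHLO: the dot_general should show bf16 operands and an f32 result
print(jax.jit(mixed_matmul).lower(q, k).as_text())
```

The printed IR is StableHLO rather than the HLO text shown above, but it carries the same bf16-in/f32-out information.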

JackCaoG commented 5 months ago

It is a known issue. We don't always cast the XLA op to the corresponding Python-side dtype, because the type promotion rules in XLA and PyTorch are not always the same; we sometimes chose to let them diverge to avoid extra casts. In JAX, does the matmul happen in f32, which would mean casting both q and k to f32 first?
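The cast trade-off can be illustrated with plain PyTorch on CPU, without XLA. The sketch below uses random inputs with the issue's shapes and compares three options: (a) a bf16 matmul with a bf16 result, (b) the same matmul followed by a cast to f32, and (c) upcasting both inputs before an f32 matmul. Option (c) should match a bf16-in/f32-out dot up to accumulation order, since bf16 to f32 conversion is exact; option (b) only changes the dtype.

```python
import torch

torch.manual_seed(0)
seq, d = 512, 128
q = torch.rand((d, seq), dtype=torch.bfloat16)
k = torch.rand((d, seq), dtype=torch.bfloat16)

# (a) bf16 in, bf16 out: what the HLO in this issue computes
a = torch.matmul(q.t(), k)
# (b) cast after the matmul: one extra convert, values are still bf16-rounded
b = torch.matmul(q.t(), k).to(torch.float32)
# (c) upcast the inputs first: two extra converts and an f32 matmul
c = torch.matmul(q.t().float(), k.float())

print(a.dtype, b.dtype, c.dtype)  # torch.bfloat16 torch.float32 torch.float32
```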

HahTK commented 5 months ago

In JAX, we see a single HLO op whose inputs are BF16 but whose output is FP32. How a backend interprets that is probably a backend-specific implementation detail. One reasonable interpretation/implementation of that HLO op is a BF16 matmul that accumulates into FP32 output buffers.

While BF16 matmuls are generally already accumulated in FP32, the final sum is immediately downcast to BF16 to fit into a BF16 output buffer. In NVIDIA NeMo implementations, we see custom fused operators created by NVIDIA that perform the backward gradient matmul and the gradient accumulation (grad_acc) in a single op. This has the effect of "possibly" bypassing the intermediate BF16 downcast.
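A small pure-Python sketch of why skipping the intermediate downcast can matter. It emulates FP32 and BF16 rounding with struct (round-to-nearest-even, NaN not handled) and assumes a hypothetical per-step gradient contribution of 1 + 2**-8: exact in FP32, but exactly halfway between two BF16 values, so it rounds down to 1.0. Accumulating it 256 times into an FP32 grad_acc gives different totals depending on whether the matmul output passed through a BF16 buffer first.

```python
import struct

def to_f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x: float) -> float:
    """Round to BF16 via round-to-nearest-even on the FP32 bit pattern (NaN not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def dot_fp32_acc(a, b):
    """Dot product of BF16-valued vectors with an FP32 accumulator.

    BF16 significands are 8 bits, so each product needs at most 16 bits and is exact in FP32.
    """
    acc = 0.0
    for x, y in zip(a, b):
        acc = to_f32(acc + x * y)
    return acc

# Hypothetical per-step contribution: BF16 inputs, FP32 accumulation
a = [to_bf16(1.0), to_bf16(2.0**-8)]
b = [to_bf16(1.0), to_bf16(1.0)]
partial = dot_fp32_acc(a, b)  # 1 + 2**-8, exact in FP32; a tie that rounds to 1.0 in BF16

grad_acc_fused = 0.0  # FP32 grad_acc fed directly with the FP32 matmul result
grad_acc_split = 0.0  # FP32 grad_acc fed from a BF16 matmul output buffer
for _ in range(256):
    grad_acc_fused = to_f32(grad_acc_fused + partial)
    grad_acc_split = to_f32(grad_acc_split + to_bf16(partial))

print(partial, to_bf16(partial))       # 1.00390625 1.0
print(grad_acc_fused, grad_acc_split)  # 257.0 256.0
```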

Generalizing: low-precision matmuls (BF16 or lower) that feed directly into other low-precision matmuls gain nothing from keeping the higher-precision result of FP32 accumulation. However, the consumers of a matmul output may include operators other than low-precision matmuls, and those could benefit from the added precision. The value of this will likely grow as precision goes down.
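The claim that the benefit grows as precision drops can be sanity-checked by rounding one FP32 value to narrower significands. This sketch models only significand width (8 bits for BF16, 4 for FP8 E4M3, 3 for FP8 E5M2, counting the implicit bit) and ignores exponent range, subnormals and saturation; the value 1/3 is an arbitrary stand-in for a matmul output.

```python
import math

def round_to_precision(x: float, p: int) -> float:
    """Round x to p significand bits (incl. the implicit bit), round-half-to-even.

    Only significand width is modelled: exponent range, subnormals and
    saturation of real low-precision formats are ignored.
    """
    if x == 0.0 or not math.isfinite(x):
        return x
    m, e = math.frexp(x)  # x == m * 2**e, with 0.5 <= |m| < 1
    return math.ldexp(round(m * 2**p), e - p)  # round() is half-to-even

# Hypothetical FP32 matmul output value handed to a non-matmul consumer
y = round_to_precision(1 / 3, 24)  # 24 significand bits ~ FP32

for name, p in [("BF16", 8), ("FP8 E4M3", 4), ("FP8 E5M2", 3)]:
    rel_err = abs(round_to_precision(y, p) - y) / abs(y)
    print(f"{name:<9} p={p}: relative rounding error {rel_err:.1e}")
```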

Questions:

  1. How hard would it be to support input and output buffers with different precisions in PT/XLA? It seems to be legal in HLO.
  2. Is there a reason why we would not want to do this? The example I gave above of how a backend might implement this op does not require any additional casting, so it would seem that the promotion rules are not a hindrance?

JackCaoG commented 5 months ago

OK, now I understand. The difficulty mainly comes from the fact that we currently let XLA do its own shape inference and never try to override XLA's shape decisions, for example https://github.com/pytorch/xla/blob/6fadbf5dd37774336f28244441f8ccc799f0b2e9/torch_xla/csrc/ops/bernoulli.cpp#L20-L31

The easier approach would be: whenever PyTorch's shape inference and XLA's shape inference disagree, we emit a cast op in the HLO. However, this is not what you want, and frankly I don't know whether it is the right thing to do.

The "right approach" might be to special-case the matmul and somehow override the HLO being generated. I also imagine XLA must have a way for us to specify the output shape we want and check whether it is valid, but I don't know where that API resides. Open to suggestions and contributions.
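The two approaches produce the same dtype but not the same values. A JAX sketch (used here only as a convenient way to emit both lowerings; random inputs are illustrative) contrasts a bf16 dot followed by a cast with a single dot whose result element type is f32. In XLA's own client API, the corresponding knob appears to be the optional preferred_element_type argument of the Dot/DotGeneral builders; treat that as a pointer to verify, not as a confirmed PyTorch/XLA integration path.

```python
import jax
import jax.numpy as jnp

kq, kk = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.uniform(kq, (128, 512)).astype(jnp.bfloat16)
k = jax.random.uniform(kk, (128, 512)).astype(jnp.bfloat16)

# "Insert a cast" lowering: bf16 dot, then convert. The dtype matches PyTorch's
# view, but every element was already rounded to bf16 before the convert.
cast_after = jax.lax.dot(q.T, k).astype(jnp.float32)

# "Special-case matmul" lowering: one dot whose result element type is f32
mixed = jax.lax.dot(q.T, k, preferred_element_type=jnp.float32)

print(cast_after.dtype, mixed.dtype)  # float32 float32
# cast_after survives a round trip through bf16 unchanged; mixed need not
print(bool(jnp.all(cast_after.astype(jnp.bfloat16).astype(jnp.float32) == cast_after)))  # True
print(float(jnp.max(jnp.abs(mixed - cast_after))))
```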

HahTK commented 5 months ago

Two things:

  1. We definitely do not want a cast op in the final HLO. However, inserting a cast op that we then post-process away with an HLO pass might be a solution. Is there a place in PT/XLA where we can insert HLO processing passes that take in some HLO and return a modified version?

  2. Looking at the Bernoulli example here, I see an xla::ConvertElementType. Does that just create a cast operator in HLO, or does it change the output PrimitiveType without a cast?

I am not sure because, on the one hand, I see ConvertElementType listed in the XLA operation semantics here, which suggests that it is an explicit cast op. On the other hand, when I looked at the Cast operator lowering here, I did not see ConvertElementType being used. So I am wondering how this works.
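As far as I know, xla::ConvertElementType emits an explicit HLO convert instruction. Assuming that, the post-processing idea would look for a convert whose operand is a dot with no other users and fold the two into one dot with the wider result type. The sketch below is a toy, text-level illustration of that rewrite (fold_dot_convert and the sample module are hypothetical); real XLA passes operate on the in-memory HloModule in C++, and this says nothing about where such a pass could hook into PT/XLA. Note that the fold deliberately changes numerics, since the dot result is no longer rounded to BF16 first.

```python
import re

# One HLO text instruction: [ROOT ]name = type opcode(operands)attributes
INSTR = re.compile(
    r"^(?P<indent>\s*)(?P<root>ROOT )?(?P<name>[\w.\-]+) = "
    r"(?P<type>\(.*?\)|\S+) (?P<op>[\w\-]+)\((?P<args>[^)]*)\)(?P<attrs>.*)$"
)

def fold_dot_convert(hlo_text: str) -> str:
    """Toy pass: turn `c = f32[..] convert(d)` with `d = bf16[..] dot(...)`
    into a single `c = f32[..] dot(...)`, if the convert is the dot's only user."""
    lines = hlo_text.splitlines()
    matches = [INSTR.match(line) for line in lines]
    defs = {m["name"]: i for i, m in enumerate(matches) if m}
    uses = {}
    for m in matches:
        for arg in (m["args"].split(",") if m else []):
            uses[arg.strip()] = uses.get(arg.strip(), 0) + 1

    dropped = set()
    for i, m in enumerate(matches):
        if not m or m["op"] != "convert":
            continue
        src = m["args"].strip()
        j = defs.get(src)
        if j is None or matches[j]["op"] != "dot" or uses.get(src) != 1:
            continue
        dot = matches[j]
        # Re-emit the dot at the convert's position, under the convert's name and
        # result type, so existing users of the convert need no renaming.
        lines[i] = (f"{m['indent']}{m['root'] or ''}{m['name']} = {m['type']} "
                    f"dot({dot['args']}){dot['attrs']}")
        dropped.add(j)
    return "\n".join(line for n, line in enumerate(lines) if n not in dropped)

HLO = """ENTRY main {
  p0 = bf16[128,512]{1,0} parameter(0)
  p1 = bf16[128,512]{1,0} parameter(1)
  transpose = bf16[512,128]{1,0} transpose(p1), dimensions={1,0}
  dot = bf16[512,512]{1,0} dot(transpose, p0), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  convert = f32[512,512]{1,0} convert(dot)
  ROOT tuple = (f32[512,512]{1,0}) tuple(convert)
}"""

print(fold_dot_convert(HLO))
```

If the bf16 dot result is also used elsewhere, the sketch leaves the module unchanged, because that user still expects the BF16 value.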