HahTK opened this issue 5 months ago
It is a known issue: we don't always cast the XLA op to the corresponding Python dtype, because the type promotion rules in XLA and PyTorch are not always the same. We made the choice to let them diverge in some cases to avoid extra casts. In JAX, will the matmul happen in f32, i.e. casting both q and k to f32 first?
In JAX, we see a single HLO op where the inputs are BF16 but the output is FP32. How a backend interprets that is probably a backend-specific implementation detail. One reasonable interpretation/implementation of that HLO op is a BF16 matmul that accumulates into an FP32 output buffer.
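For reference, a minimal JAX sketch (not from the original report; it assumes `preferred_element_type` is the knob used) that produces a single dot with bf16 operands and an f32 result:

```python
import jax
import jax.numpy as jnp

def matmul_f32_out(q, k):
    # A single dot op: operands stay bf16, the result element type is f32.
    return jax.lax.dot(q, k, preferred_element_type=jnp.float32)

q = jnp.ones((512, 512), dtype=jnp.bfloat16)
k = jnp.ones((512, 512), dtype=jnp.bfloat16)

# The lowered module shows a dot with bf16 inputs and an f32 output.
print(jax.jit(matmul_f32_out).lower(q, k).as_text())
```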
While BF16 matmuls are generally already accumulated in FP32, the final sum is immediately downcast to BF16 to fit into a BF16 output buffer. In NVIDIA NeMo implementations, we see custom fused operators created by NVIDIA that perform the backward grad matmul and grad_acc in a single op. This has the effect of "possibly" bypassing the intermediate BF16 downcast.
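As a rough illustration of that output downcast (a self-contained PyTorch sketch, not tied to any particular backend or to the NeMo kernels):

```python
import torch

torch.manual_seed(0)
a = torch.randn(512, 512).bfloat16()
b = torch.randn(512, 512).bfloat16()

# Upcast first: products and sums stay in FP32 end to end.
ref = a.float() @ b.float()

# Plain BF16 matmul: even if accumulation happens in FP32 internally,
# the result is rounded back down into a BF16 output buffer.
out = (a @ b).float()

# The gap reflects the precision lost to the BF16 output rounding.
print((ref - out).abs().max())
```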
Generalizing: low-precision matmuls (BF16 or lower) that chain into other low-precision matmuls have no reason to keep the higher-precision result of FP32 accumulation. However, there may be cases where the consumer of the matmul output includes other operators that are not low-precision matmuls and could benefit from the added precision. The value of something like this will likely go up as precision goes down.
Questions:
Ok, now I understand. The difficulty mainly comes from the fact that we currently let XLA do its shape inference and never try to override XLA's shape decision. For example: https://github.com/pytorch/xla/blob/6fadbf5dd37774336f28244441f8ccc799f0b2e9/torch_xla/csrc/ops/bernoulli.cpp#L20-L31
The easier approach would be: when PyTorch's shape inference and XLA's shape inference do not agree, we issue a cast op in the HLO. However, this is not what you want, and frankly I don't know whether it is the right thing to do.
The "right approach" might be to special-case the matmul and somehow overwrite the HLO being generated. I also imagine XLA must have a way for us to tell it what output shape we want and whether it is valid; I just don't know where that API resides. Open to suggestions and contributions.
2 things:
1. We definitely do not want a cast op in the final HLO. However, inserting a cast op that we can then post-process in an HLO pass might be a solution. Is there a place in PT/XLA where we can insert HLO processing passes that take in some HLO and return a different one?
2. Looking at the Bernoulli example here, I see an xla::ConvertElementType. Does that just create a cast operator in HLO, or does it convert the output PrimitiveType without a cast? I am not sure because, on one hand, I see ConvertElementType mentioned in the XLA Op semantics here, which suggests that it is an explicit cast op. On the other hand, when I tried looking at the Cast operator lowering here, I do not see ConvertElementType getting used, so I am wondering how this works. (One way to check empirically is sketched below.)
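One hedged way to check (assuming torch_xla's internal `_get_xla_tensors_hlo` debug helper) is to dump the HLO for an explicit dtype cast on an XLA tensor and look for a convert instruction:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

x = torch.randn(4, 4, dtype=torch.bfloat16, device=device)
y = x.to(torch.float32)

# If casts are lowered as explicit HLO ops, the dump should contain
# something like: f32[4,4] convert(bf16[4,4] ...)
print(torch_xla._XLAC._get_xla_tensors_hlo([y]))
```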
❓ Questions and Help
When doing torch.matmul(in, other, out=c) where the dtype of c is not the same as the inputs', the dtype of c is only respected on torch but not on XLA. Is this the expected behavior or a bug?
Example
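(A minimal sketch of the kind of repro being described; the shapes and dtypes are inferred from the HLO below, and the `_get_xla_tensors_hlo` call is just one way to dump the graph.)

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# BF16 inputs with an FP32 "out=" tensor (512x512 taken from the dot below).
a = torch.randn(512, 512, dtype=torch.bfloat16, device=device)
b = torch.randn(512, 512, dtype=torch.bfloat16, device=device)
c = torch.empty(512, 512, dtype=torch.float32, device=device)

torch.matmul(a, b, out=c)

# Inspect the pending HLO for c to see the dtype the dot is lowered with.
print(torch_xla._XLAC._get_xla_tensors_hlo([c]))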
This will generate the following hlo.pbtxt:
Therefore the computation in HLO occurs in BF16. However, the output will show c.dtype = float32, so the output tensor dtype on torch is still correct.

Side note: JAX will respect c.dtype and lower it to
dot = fp32[512,512]{1,0} dot(transpose, p0), lhs_contracting_dims={1}, rhs_contracting_dims={0}