JackCaoG opened 2 years ago
So the cast was coming from a copy_:
2429    void XLATensor::copy_(XLATensorPtr& input, XLATensorPtr& src) {
2430      if (input->GetDevice() == src->GetDevice()) {
2431        torch::lazy::Value copy_value;
2432        if (input->dtype() == src->dtype()) {
2433          copy_value = src->GetIrValue();
2434        } else {
2435>         copy_value = torch::lazy::MakeNode<Cast>(src->GetIrValue(),
2436                                                   input->dtype(), src->dtype());
2437        }
2438        input->SetIrValue(MaybeExpand(copy_value, input->shape()));
2439      }
(gdb) p input->dtype()
$1 = c10::ScalarType::BFloat16
(gdb) p src->dtype()
$2 = c10::ScalarType::Float
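To see this path from Python, here is a minimal sketch (my own illustration, not from the original debugging session) that triggers XLATensor::copy_ with mismatched dtypes and dumps the resulting IR; it reuses the _get_xla_tensors_text helper that appears later in this thread:

# Hedged sketch: a cross-dtype copy_ between XLA tensors should take the
# dtype-mismatch branch of XLATensor::copy_ above and record an xla::cast node.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
dst = torch.zeros(2, 2, dtype=torch.bfloat16, device=device)
src = torch.ones(2, 2, dtype=torch.float32, device=device)
dst.copy_(src)  # f32 -> bf16 copy; expect a cast in the traced IR
print(torch_xla._XLAC._get_xla_tensors_text([dst]))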
The first cast is from f32 -> bf16, which is from the forward graph:
IR {
......
%9 = f32[66,44]{1,0} aten::uniform(%8, %7, %6), location=kaiming_uniform_@init.py:412
%10 = bf16[66,44]{1,0} xla::cast(%9), location=convert@module.py:981, type=bf16, dtype=BFloat16, stype=Float, ROOT=1
......
}
The second cast is from bf16 -> f32, which is the unnecessary cast we are talking about here. Tracing it all the way up the call stack, I found it comes from
270   Tensor to(const Tensor& self, ScalarType dtype, bool non_blocking, bool copy, c10::optional<c10::MemoryFormat> optional_memory_format) {
271>    return to_impl(
272         self,
273         dtype,
274         nullopt,
275         nullopt,
276         nullopt,
277         non_blocking,
278         copy,
279         optional_memory_format);
280   }
where dtype is f32:
(gdb) p dtype
$12 = c10::ScalarType::Float
If I keep going, I finally reach /pytorch/torch/csrc/autograd/engine.cpp and see
741       if (c10::typeMetaToScalarType(metadata.options().dtype()) !=
742           grad.scalar_type()) {
743>        grad = grad.to(c10::typeMetaToScalarType(metadata.options().dtype()));
744       }
grad.scalar_type() is bf16 but this metadata type is f32... Trying to figure out why.
I think I've kind of figured out what happened. If you look at the forward graph for the cast_after_init case, it is
IR {
%0 = bf16[] prim::Constant(), location=<module>@debug_bf16.py:13, value=1
%1 = bf16[22,44]{1,0} aten::expand(%0), location=<module>@debug_bf16.py:13, size=(22, 44), ROOT=0
%2 = s64[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=CPU:0
%3 = s64[] prim::Constant(), location=kaiming_uniform_@init.py:412, value=214013
%4 = s64[] aten::mul(%3, %2), location=kaiming_uniform_@init.py:412
%5 = s64[] prim::Constant(), location=kaiming_uniform_@init.py:412, value=2531011
%6 = s64[] aten::add(%5, %4), location=kaiming_uniform_@init.py:412
%7 = f32[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=CPU:0
%8 = f32[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=CPU:0
%9 = f32[66,44]{1,0} aten::uniform(%8, %7, %6), location=kaiming_uniform_@init.py:412
%10 = bf16[66,44]{1,0} xla::cast(%9), location=convert@module.py:981, type=bf16, dtype=BFloat16, stype=Float, ROOT=1
%11 = bf16[44,66]{0,1} aten::permute(%10), location=forward@linear.py:114, dims=(1, 0)
%12 = bf16[22,66]{1,0} aten::mm(%1, %11), location=forward@linear.py:114, ROOT=2
%13 = bf16[] aten::sum(%12), location=<module>@debug_bf16.py:25, dimensions=(0, 1), keep_reduced_dimensions=0, dtype=15, ROOT=3
}
and for the case where the dtype is passed during init, it is
IR {
%0 = bf16[] prim::Constant(), location=<module>@debug_bf16.py:13, value=1
%1 = bf16[22,44]{1,0} aten::expand(%0), location=<module>@debug_bf16.py:13, size=(22, 44), ROOT=0
%2 = s64[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=CPU:0
%3 = s64[] prim::Constant(), location=kaiming_uniform_@init.py:412, value=214013
%4 = s64[] aten::mul(%3, %2), location=kaiming_uniform_@init.py:412
%5 = s64[] prim::Constant(), location=kaiming_uniform_@init.py:412, value=2531011
%6 = s64[] aten::add(%5, %4), location=kaiming_uniform_@init.py:412
%7 = bf16[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=CPU:0
%8 = bf16[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=CPU:0
%9 = bf16[66,44]{1,0} aten::uniform(%8, %7, %6), location=kaiming_uniform_@init.py:412, ROOT=1
%10 = bf16[44,66]{0,1} aten::permute(%9), location=forward@linear.py:114, dims=(1, 0)
%11 = bf16[22,66]{1,0} aten::mm(%1, %10), location=forward@linear.py:114, ROOT=2
%12 = bf16[] aten::sum(%11), location=<module>@debug_bf16.py:25, dimensions=(0, 1), keep_reduced_dimensions=0, dtype=15, ROOT=3
}
The difference is that the device data is bf16 from the beginning if the dtype was passed during init. If we do to(bf16) after init, the device data will be uploaded as f32 and then get cast to bf16. I think the cast in the backward was a result of the cast in the forward (the grad is in bf16 though; I guess autograd looks it up and finds its input was f32?).
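As a rough illustration of the two paths above (my sketch, assuming the usual lazy-tensor tracing behavior, not part of the original session), the difference boils down to whether the parameter's device data arrives in bf16 or arrives in f32 and then gets an xla::cast in the IR:

# Sketch contrasting the two initialization paths discussed above.
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# dtype set before the upload: the device data arrives as bf16
w_direct = torch.ones(66, 44, dtype=torch.bfloat16).to(device)

# cast after the upload: the device data arrives as f32 and an xla::cast to
# bf16 is recorded in the IR (the cast_after_init situation above)
w_cast = torch.ones(66, 44).to(device).to(torch.bfloat16)

print(torch_xla._XLAC._get_xla_tensors_text([w_direct, w_cast]))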
I see. @JackCaoG Thanks for looking into it!
I think the cast in the backward was a result of the cast in the forward (the grad is in bf16 though; I guess autograd looks it up and finds its input was f32?).
Yeah, this seems to be where the cast to f32 comes in. However, another weird thing is that when I try further inserting an xm.mark_step() before the forward pass and putting the linear module construction under torch.no_grad(), the unnecessary cast in the backward pass still happens, although now the forward pass (traced by autograd) will directly start from bfloat16 device data.
Specifically, I'm running the following to separate construction, forward, and backward:
import torch
import torch_xla.core.xla_model as xm
device = xm.xla_device()
batchsize = 22
inputsize = 44
outputsize = 66
with torch.no_grad():
    # input
    x = torch.ones(batchsize, inputsize, dtype=torch.bfloat16, device=device)
    # module
    linear = torch.nn.Linear(inputsize, outputsize, device=device, bias=False)
    linear = linear.to(torch.bfloat16)
xm.mark_step()
y = linear(x)
loss = y.sum()
xm.mark_step()
loss.backward()
xm.mark_step()
which gives 3 IR graphs.
The first one is nn.Linear module construction and cast to bf16, which is under torch.no_grad so shouldn't be traced by autograd:
IR {
%0 = bf16[] prim::Constant(), value=1
%1 = bf16[22,44]{1,0} aten::expand(%0), size=(22, 44), ROOT=0
%2 = s64[] xla::device_data(), device=TPU:0
%3 = s64[] prim::Constant(), value=214013
%4 = s64[] aten::mul(%3, %2)
%5 = s64[] prim::Constant(), value=2531011
%6 = s64[] aten::add(%5, %4)
%7 = f32[] xla::device_data(), device=TPU:0
%8 = f32[] xla::device_data(), device=TPU:0
%9 = f32[66,44]{1,0} aten::uniform(%8, %7, %6)
%10 = bf16[66,44]{1,0} xla::cast(%9), type=bf16, dtype=BFloat16, stype=Float, ROOT=1
}
The second one is the forward pass, which now directly starts with bfloat16 device data:
IR {
%0 = bf16[66,44]{0,1} xla::device_data(), device=TPU:0
%1 = bf16[44,66]{1,0} aten::permute(%0), dims=(1, 0)
%2 = bf16[22,44]{1,0} xla::device_data(), device=TPU:0
%3 = bf16[22,66]{1,0} aten::mm(%2, %1), ROOT=0
%4 = bf16[] aten::sum(%3), dimensions=(0, 1), keep_reduced_dimensions=0, dtype=15, ROOT=1
}
And the third one is the backward pass, which unfortunately still involves a cast to f32:
IR {
%0 = bf16[] prim::Constant(), value=1
%1 = bf16[22,66]{1,0} aten::expand(%0), size=(22, 66)
%2 = bf16[22,44]{1,0} xla::device_data(), device=TPU:0
%3 = bf16[44,22]{0,1} aten::permute(%2), dims=(1, 0)
%4 = bf16[44,66]{1,0} aten::mm(%3, %1)
%5 = f32[44,66]{1,0} xla::cast(%4), type=f32, dtype=Float, stype=BFloat16
%6 = f32[66,44]{0,1} aten::permute(%5), dims=(1, 0)
%7 = bf16[66,44]{0,1} xla::cast(%6), type=bf16, dtype=BFloat16, stype=Float, ROOT=0
}
It seems like even in this case, somehow autograd still remembers the old f32 dtype somewhere, although we have torch.no_grad() and an extra xm.mark_step() before starting the forward pass.
And this cast to f32 in the backward pass happens not only in the first iteration but also in subsequent iterations if we do forward and backward multiple times, even though the forward pass now starts with bf16 device data for the linear.weight parameter in %0 = bf16[66,44]{0,1} xla::device_data(), device=TPU:0.
So it must still be remembering its old f32 dtype somewhere to retain this behavior in subsequent iterations that don't involve a cast in the forward pass (not sure where it remembers this f32 dtype).
PyTorch autograd engine is a bit beyond my knowledge. @bdhirsh do you know what is metadata in https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/engine.cpp#L743 ?
it is from
for (const auto i : c10::irange(grads.size())) {
  const auto& edge = edges[i];
  if (!edge.is_valid())
    continue;

  const auto& metadata = edge.function->input_metadata(edge.input_nr);
but I don't know what this edge means.
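For what it's worth, the edges can be poked at from Python via grad_fn.next_functions; a small sketch of my own, just to illustrate what edge.function / edge.input_nr roughly correspond to:

# Each autograd node exposes its outgoing edges as (function, input_nr) pairs;
# edge.function->input_metadata(edge.input_nr) in the engine is the metadata
# recorded for that input when the forward op was traced.
import torch

w = torch.ones(3, 3, requires_grad=True)
x = torch.ones(2, 3)
loss = (x @ w).sum()
print(loss.grad_fn)                 # SumBackward0
print(loss.grad_fn.next_functions)  # ((MmBackward0, 0),): the edges the engine follows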
hmm @soulitzer, do you know what the purpose of that line in autograd (here) is? Is the idea that if we downcast a tensor in the forward, then we need to make sure to upcast back to its original dtype in the backward? (e.g. float_tensor.to(dtype=torch.bfloat16)).
I can also take more of a look when I'm back next Tuesday.
FYI it seems this logic was introduced in https://github.com/pytorch/pytorch/commit/88e4cee3e70aac95dd2c18b898808ce3426cb3c9#diff-c66dfeac2a2da1867233047ec413c7e625644c672d7b38b8ec982f5605923c64.
Is the idea that if we downcast a tensor in the forward, then we need to make sure to upcast back to its original dtype in the backward? (e.g. float_tensor.to(dtype=torch.bfloat16)).
Ahh, for something like .to there should already be an autograd formula that explicitly handles that (_to_copy_backward). Though I guess it is kind of redundant due to the logic in https://github.com/pytorch/pytorch/commit/88e4cee3e70aac95dd2c18b898808ce3426cb3c9#diff-c66dfeac2a2da1867233047ec413c7e625644c672d7b38b8ec982f5605923c64, which relieves the writer of a backward formula from having to cast to the dtypes of the inputs, which can be tricky, for example when the dtypes of the inputs are not the same.
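A small eager-mode sketch of that formula's effect (my example, not from the thread): an out-of-place .to(dtype) with autograd enabled already casts the gradient back to the input's original dtype on its own.

# .to(bfloat16) records a backward that casts the incoming gradient back to
# the input's original dtype (float32 here).
import torch

x = torch.ones(4, requires_grad=True, dtype=torch.float32)
y = x.to(torch.bfloat16)
y.sum().backward()
print(y.dtype)       # torch.bfloat16
print(x.grad.dtype)  # torch.float32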
So, is the solution to disable autograd on the .to(bf16) call, such that the .to call no longer triggers this auto casting back?
NVM, I found the relevant .to(bf16) calls were already decorated with @torch.no_grad().
I guess the issue here is that
%6 = f32[66,44]{0,1} aten::permute(%5), dims=(1, 0)
This seems to be why the casting is happening; I am guessing this is the backward of some op?
Yes, it's the backward of F.linear().
This seems to be why the casting is happening, I am guessing this is the backward of some op?
Yeah, I think this aten::permute(%5) is the autograd-generated backward of the transpose .t() in weight.t() in the forward pass of at::linear (used by nn.Linear) from https://github.com/pytorch/pytorch/blob/23bdb570cf05f0cefdacdda5cbf73f58a2e574f4/aten/src/ATen/native/Linear.cpp#L44 (where weight is a parameter in nn.Linear that has been cast to bfloat16 before the forward pass in our case).
@ronghanghu regarding your question above: mark_step only affects PyTorch/XLA's view of how a tensor is stored; it does not affect the autograd engine, which is a layer above PyTorch/XLA. In other words, mark_step will not clear the autograd state, so it kind of makes sense why your approach in https://github.com/pytorch/xla/issues/3718#issuecomment-1185083739 doesn't work.
It seems to me that the cast to f32 is actually by design, and I get why it is needed. Is it something that is blocking?
@JackCaoG Sorry Jack, can you explain why the cast is needed?
So what happens in the forward is the following in pseudo code:
t = tensor(f32)
with no_grad:
    t = t.to(bf16)
l = F.linear(t)
So I don't see why the .to(bf16) should be affected by autograd in the backward pass.
The two extra cast ops create two extra tensors. If the tensor size is big or the number of tensors is big, this will create RAM pressure on the device.
So what happens in the forward is the following in pseudo code:
t = tensor(f32); with no_grad: t = t.to(bf16); l = F.linear(t)
So I don't see why the .to(bf16) should be affected by auto_grad.

Ah OK, I missed the no_grad part.
@soulitzer From what I can tell, the cast to f32 in the backward happened because the tensor was originally f32 and this is remembered by the autograd engine somehow. Given that t.to(bf16) happens in the no_grad section, the backward of to should not be run. Do you know why the cast to f32 still happens?
FYI, the XLA algebraic optimizer can remove this round trip to f32 during lowering, so it shouldn't affect performance. Update: I only verified this on GPU, and this may become an issue for other backends, as pointed out by @hjm-aws.
Original:
ENTRY SyncTensorsGraph.13 {
p0.6 = bf16[22,44]{1,0} parameter(0)
transpose.7 = bf16[44,22]{0,1} transpose(p0.6), dimensions={1,0}
constant.1 = bf16[] constant(1)
reshape.2 = bf16[1,1]{1,0} reshape(constant.1)
broadcast.3 = bf16[1,1]{1,0} broadcast(reshape.2), dimensions={0,1}
reshape.4 = bf16[] reshape(broadcast.3)
broadcast.5 = bf16[22,66]{1,0} broadcast(reshape.4), dimensions={}
dot.8 = bf16[44,66]{1,0} dot(transpose.7, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}
convert.9 = f32[44,66]{1,0} convert(dot.8)
transpose.10 = f32[66,44]{0,1} transpose(convert.9), dimensions={1,0}
convert.11 = bf16[66,44]{0,1} convert(transpose.10)
ROOT tuple.12 = (bf16[66,44]{0,1}) tuple(convert.11)
}
After optimization:
ENTRY SyncTensorsGraph.13 {
constant.2 = bf16[] constant(1)
broadcast.5 = bf16[22,66]{1,0} broadcast(constant.2), dimensions={}
p0.6 = bf16[22,44]{1,0} parameter(0)
dot.1 = bf16[66,44]{0,1} dot(broadcast.5, p0.6), lhs_contracting_dims={0}, rhs_contracting_dims={0}
ROOT tuple.12 = (bf16[66,44]{0,1}) tuple(dot.1)
}
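In case anyone wants to check the same thing on their backend, here is a hedged sketch of how to get at these graphs (assuming the standalone weight repro that appears later in this thread): _get_xla_tensors_hlo dumps the HLO as built by PyTorch/XLA, while the post-optimization HLO has to come from the compiler's own dump options (e.g. XLA_FLAGS=--xla_dump_to=...).

# Sketch: dump the un-optimized HLO for the pending backward graph.
# (Assumes `weight` from the cast_after_init repro below, after loss.backward().)
import torch_xla

print(torch_xla._XLAC._get_xla_tensors_hlo([weight.grad]))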
Not sure how the code is still working with no_grad actually, because the output of to would return a tensor that doesn't require grad.
The cast can happen implicitly during backward even without .to, however. To give a more concrete example of how autograd remembers the dtype of the forward:
a = torch.tensor(1., requires_grad=True, dtype=torch.float32).clone()
b = torch.tensor(1., requires_grad=True, dtype=torch.float64)
(a * b).backward()
1) When we call clone, the autograd kernel saves the dtype (float32) of the output of clone onto the graph as part of its "input metadata"
2) a * b is float64 due to binary type promotion

I'm not actually sure how relevant the above is to the issue at hand though, since running the original repro on CPU and adding prints reveals that the engine doesn't actually enter the code path from (4) above.
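Expanding the clone example above into a runnable check (my sketch): binding the leaf to a name makes the engine's cast visible in the final gradient dtype.

# The grad flowing back from the float64 product is cast back to float32 by the
# engine, using the input metadata recorded when clone was traced.
import torch

leaf = torch.tensor(1., requires_grad=True, dtype=torch.float32)
a = leaf.clone()          # CloneBackward saves float32 input metadata
b = torch.tensor(1., requires_grad=True, dtype=torch.float64)
(a * b).backward()
print(leaf.grad.dtype)    # torch.float32
print(b.grad.dtype)       # torch.float64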
mark_step only affect pytorch/xla view of how tensor is stored, it does not affect autograd engine which is a layer above the pytorch/xla.
@JackCaoG Yeah, I was aware of this. I was just curious whether there is a way to cast an nn.Module to bfloat16 in a way that looks just like it was originally constructed in bfloat16 -- so far it seems that putting the cast linear = linear.to(torch.bfloat16) under no_grad is not enough to eliminate this difference in autograd. (I'm also taking a look at this issue given that the FSDP class also has a compute_dtype that could be bfloat16 and involves such casts a lot.)
FYI, XLA algebraic optimizer can remove this round trip to f32 during lowering. So it shouldn't affect the performance.
@ymwangg Thanks, this is great to know :)
Not sure how the code is still working with no_grad actually, because the output of to would return a tensor that doesn't require grad.
@soulitzer I think this is because nn.Module either
1) explicitly turned the requires_grad flag back on in https://github.com/pytorch/pytorch/blob/f3f8d96ea69134770198dec485921f9dba45b5ed/torch/nn/modules/module.py#L665
or
2) directly put the casted new tensor into its .data in https://github.com/pytorch/pytorch/blob/f3f8d96ea69134770198dec485921f9dba45b5ed/torch/nn/modules/module.py#L660
so if a parameter originally has requires_grad == True, then it will remain requires_grad == True after linear = linear.to(torch.bfloat16).
I think this 2nd behavior is probably why we see this issue.
@hjm-aws @JackCaoG @soulitzer OK, I think I figured out the underlying cause of the issue above.
It's because nn.Module by default casts a parameter to another dtype by directly assigning its .data to the new casted tensor (in https://github.com/pytorch/pytorch/blob/f3f8d96ea69134770198dec485921f9dba45b5ed/torch/nn/modules/module.py#L660), such as
# what's happening in `linear = linear.to(torch.bfloat16)`:
p.data = p.data.to(torch.bfloat16)
Unfortunately, this leaves the parameter's data cast to bfloat16 but keeps the parameter's metadata in float32. (So now we have a discrepancy in dtype and hence the cast to float32 in the backward IR.)
The solution to the example above is to use torch.__future__.set_overwrite_module_params_on_conversion(True). For example, with the code below, we now get a clean backward pass without any cast back to fp32:
import torch
import torch_xla.core.xla_model as xm
torch.__future__.set_overwrite_module_params_on_conversion(True)
device = xm.xla_device()
batchsize = 22
inputsize = 44
outputsize = 66
# input
x = torch.ones(batchsize, inputsize, dtype=torch.bfloat16, device=device)
# module
linear = torch.nn.Linear(inputsize, outputsize, device=device, bias=False)
linear = linear.to(torch.bfloat16)
xm.mark_step()
y = linear(x)
loss = y.sum()
xm.mark_step()
loss.backward()
xm.mark_step()
This has a clean backward pass IR (no cast back to fp32 involved):
IR {
%0 = bf16[] prim::Constant(), value=1
%1 = bf16[22,66]{1,0} aten::expand(%0), size=(22, 66)
%2 = bf16[22,44]{1,0} xla::device_data(), device=TPU:0
%3 = bf16[44,22]{0,1} aten::permute(%2), dims=(1, 0)
%4 = bf16[44,66]{1,0} aten::mm(%3, %1)
%5 = bf16[66,44]{0,1} aten::permute(%4), dims=(1, 0), ROOT=0
}
@hjm-aws running torch.__future__.set_overwrite_module_params_on_conversion(True) could fix the example above but not the FSDP case with compute_dtype=torch.bfloat16. I think I need to roll out a similar solution in FSDP for such cases. But since @ymwangg mentioned that XLA can eliminate this cast, I guess this is not blocking you :)
@soulitzer I think the root cause is that p.data = p.data.to(torch.bfloat16) in PyTorch core doesn't update p's metadata to have bfloat16 dtype, leaving a discrepancy between the metadata dtype and the actual dtype.
Maybe assigning to its .data should change a tensor's metadata such as dtype, shape, and device? (But that would be a backward-incompatible change in PyTorch core, so probably not a good idea either.)
Also, is there a way for a user to explicitly update a tensor p's metadata (without making another torch.Tensor object, so that we can keep the id of p)?
@ronghanghu Thanks for the investigation! It's really helpful.
As to the optimization passes, I need to verify with our compiler team to ensure these optimization passes can be adopted. I will update the thread either way.
I think the root cause is that p.data = p.data.to(torch.bfloat16) in PyTorch core doesn't update p's metadata to have bfloat16 dtype, leaving a discrepancy between the metadata dtype and the actual dtype.
@ronghanghu Good catch!
I'd point out that the metadata can only go out-of-date when p is not a leaf node. When p is a leaf node, the following should take care of updating the metadata (the input metadata of its grad accumulator): https://github.com/pytorch/pytorch/blob/516f3198d65f6932299d41ddbb98c26de5f0a367/torch/csrc/autograd/variable.cpp#L484-L498
So another way to fix the issue is just to ensure that your parameters are actually leaf nodes before doing .to(dtype) (or figure out why your parameters aren't leaf nodes):
for p in linear.parameters():
p.detach_().requires_grad_(True)
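For the standalone repro earlier in the thread, this is presumably where that fix would slot in; a sketch under that assumption, not something verified in this issue:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
linear = torch.nn.Linear(44, 66, device=device, bias=False)
linear = linear.to(torch.bfloat16)
# Re-detach the parameters in place so their grad accumulators (and the input
# metadata they hold) are recreated with the new bf16 dtype.
for p in linear.parameters():
    p.detach_().requires_grad_(True)
xm.mark_step()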
I'd point out that metadata can only go out-of-date when p is not a leaf node. When p is a leaf node, the following should take care of updating the metadata (input metadata of its grad accumulator): https://github.com/pytorch/pytorch/blob/516f3198d65f6932299d41ddbb98c26de5f0a367/torch/csrc/autograd/variable.cpp#L484-L498
So another way to fix the issue is just to ensure that your parameters are actually leaf nodes before doing .to(dtype) (or figure out why your parameters aren't leaf nodes):
@soulitzer Thanks for the explanation! However, I checked the example above and found that its parameter (linear.weight) is already a leaf node. Actually, it is a leaf node both before and after the cast linear = linear.to(torch.bfloat16), regardless of whether I use torch.__future__.set_overwrite_module_params_on_conversion(True) in https://github.com/pytorch/xla/issues/3718#issuecomment-1190696341.
This issue is also reproducible without using nn.Linear or any nn.Module, but only torch.Tensor. For example, if I first construct a weight tensor in float32 and then cast it to bfloat16, there will be an unexpected cast to float32 in the backward pass IR:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
cast_after_init = True
# cast_after_init = False
# weight
if cast_after_init:
    weight = torch.ones(66, 44, device=device, requires_grad=True)
    assert weight.is_leaf
    with torch.no_grad():
        weight.data = weight.data.to(torch.bfloat16)
else:
    weight = torch.ones(66, 44, device=device, requires_grad=True, dtype=torch.bfloat16)
xm.mark_step()
# forward pass
x = torch.ones(22, 44, dtype=torch.bfloat16, device=device)
y = torch.nn.functional.linear(x, weight)
loss = y.sum()
xm.mark_step()
# backward pass
loss.backward()
print(f"cast_after_init: {cast_after_init}")
print("backward:", torch_xla._XLAC._get_xla_tensors_text([weight.grad]))
xm.mark_step()
which prints
cast_after_init: True
backward: IR {
%0 = bf16[] prim::Constant(), value=1
%1 = bf16[22,66]{1,0} aten::expand(%0), size=(22, 66)
%2 = bf16[22,44]{1,0} xla::device_data(), device=TPU:0
%3 = bf16[44,22]{0,1} aten::permute(%2), dims=(1, 0)
%4 = bf16[44,66]{1,0} aten::mm(%3, %1)
%5 = f32[44,66]{1,0} xla::cast(%4), type=f32, dtype=Float, stype=BFloat16
%6 = f32[66,44]{0,1} aten::permute(%5), dims=(1, 0)
%7 = bf16[66,44]{0,1} xla::cast(%6), type=bf16, dtype=BFloat16, stype=Float, ROOT=0
}
On the contrary, if I directly construct a weight tensor in bfloat16 (i.e. set cast_after_init=False in the script above), then the cast to float32 doesn't happen in the backward pass, and it prints
cast_after_init: False
backward: IR {
%0 = bf16[] prim::Constant(), value=1
%1 = bf16[22,66]{1,0} aten::expand(%0), size=(22, 66)
%2 = bf16[22,44]{1,0} xla::device_data(), device=TPU:0
%3 = bf16[44,22]{0,1} aten::permute(%2), dims=(1, 0)
%4 = bf16[44,66]{1,0} aten::mm(%3, %1)
%5 = bf16[66,44]{0,1} aten::permute(%4), dims=(1, 0), ROOT=0
}
Since the only difference between these two cases is whether weight is cast to or directly constructed in bfloat16, I suspect the cast weight.data = weight.data.to(torch.bfloat16) in the first case (cast_after_init=True) didn't correctly update the metadata of the weight tensor to bfloat16 dtype, even though weight is a leaf node here.
Yeah, I'm not too sure what is going on - likely XLA specific? I reproduced this on Colab (1.11), but from adding the following check it looks like the input_metadata of the weight is bfloat16 as expected after setting the .data:
import os
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch.utils.cpp_extension import load_inline
source = """
#include "ATen/core/Tensor.h"
int get_metadata_dtype(at::Tensor& t) {
auto tmp = t.clone();
return static_cast<int>(c10::typeMetaToScalarType(tmp.grad_fn()->next_edge(0).function->input_metadata(0).dtype()));
}
"""
inline = load_inline(
name="inline_extension",
cpp_sources=[source],
functions=["get_metadata_dtype"],
extra_include_paths=['/usr/local/lib/python3.7/dist-packages/torch/include'],
build_directory="/content/torch_cpp_extension"
)
device = xm.xla_device()
cast_after_init = True
# cast_after_init = False
# weight
if cast_after_init:
    weight = torch.ones(66, 44, device=device, requires_grad=True)
    print(inline.get_metadata_dtype(weight))  # 6 = float
    assert weight.is_leaf
    with torch.no_grad():
        weight.data = weight.data.to(torch.bfloat16)
    print(inline.get_metadata_dtype(weight))  # 15 = bfloat16 (if non-leaf, this is also 6)
else:
    weight = torch.ones(66, 44, device=device, requires_grad=True, dtype=torch.bfloat16)
# forward pass
x = torch.ones(22, 44, dtype=torch.bfloat16, device=device)
y = torch.nn.functional.linear(x, weight)
loss = y.sum()
xm.mark_step()
# backward pass
loss.backward()
print(f"cast_after_init: {cast_after_init}")
print("backward:", torch_xla._XLAC._get_xla_tensors_text([weight.grad]))
xm.mark_step()
But since @ymwangg mentioned that the XLA can eliminate this cast, I guess this is not blocking you :)
@ronghanghu Ronghang, I can confirm that our compiler can get rid of the double cast.
This also occurs in AMP scenarios, even after setting torch.__future__.set_overwrite_module_params_on_conversion(True) @ronghanghu. For example:
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast
torch.__future__.set_overwrite_module_params_on_conversion(True)
device = xm.xla_device()
batchsize = 22
inputsize = 44
outputsize = 66
# input
x = torch.ones(batchsize, inputsize, dtype=torch.float, device=device)
# module
linear = torch.nn.Linear(inputsize, outputsize, device=device, bias=False)
xm.mark_step()
with autocast(device, dtype=torch.bfloat16):
    y = linear(x)
    loss = y.sum()
xm.mark_step()
loss.backward()
xm.mark_step()
The backward IR is:
IR {
%0 = f32[] prim::Constant(), xla_shape=f32[], value=1
%1 = f32[] aten::expand(%0), xla_shape=f32[], size=(), dynamic_dims=()
%2 = f32[22,66]{1,0} aten::expand(%1), xla_shape=f32[22,66]{1,0}, size=(22, 66)
%3 = bf16[22,66]{1,0} xla::cast(%2), xla_shape=bf16[22,66]{1,0}, type=bf16, dtype=BFloat16, stype=Float
%4 = bf16[22,44]{1,0} xla::device_data(), xla_shape=bf16[22,44]{1,0}, device=GPU:0
%5 = bf16[44,22]{0,1} aten::permute(%4), xla_shape=bf16[44,22]{0,1}, dims=(1, 0)
%6 = bf16[44,66]{1,0} aten::mm(%5, %3), xla_shape=bf16[44,66]{1,0}
%7 = bf16[66,44]{0,1} aten::permute(%6), xla_shape=bf16[66,44]{0,1}, dims=(1, 0)
%8 = f32[66,44]{0,1} xla::cast(%7), xla_shape=f32[66,44]{0,1}, type=f32, dtype=Float, stype=BFloat16, ROOT=0
}
There is an extra cast %8 in the IR.
🐛 Bug
Steps to reproduce the behavior:
Run the cast_after_init repro script above to dump the IR. The observation is that with Case 1 (cast_after_init = True), the backward pass IR (corresponding to the loss.backward() between the two mark_step calls) somehow has an extra cast back to float32 around the 2nd permute, while with Case 2 (cast_after_init = False), the backward pass IR stays in bfloat16 as expected.
Expected behavior
No cast back to f32 is produced in the backward pass.
Additional context
Reported by @ronghanghu and @hjm-aws