pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Cast linear model to bf16 produces unexpected f32 #3718

Open JackCaoG opened 2 years ago

JackCaoG commented 2 years ago

πŸ› Bug

Steps to reproduce the behavior:

Run:

# save this file as "debug_bf16.py"

import torch
import torch_xla.core.xla_model as xm

cast_after_init = True
# cast_after_init = False

device = xm.xla_device()
batchsize = 22
inputsize = 44
outputsize = 66

# input
x = torch.ones(batchsize, inputsize, dtype=torch.bfloat16, device=device)

# module
if cast_after_init:
    # Case 1
    linear = torch.nn.Linear(inputsize, outputsize, device=device, bias=False)
    linear = linear.to(torch.bfloat16)
else:
    # Case 2
    linear = torch.nn.Linear(inputsize, outputsize, device=device, dtype=torch.bfloat16, bias=False)

y = linear(x)
loss = y.sum()

xm.mark_step()
loss.backward()
xm.mark_step()

with

XLA_SAVE_TENSORS_FILE=./debug_bf16_ir.txt \
XLA_SAVE_TENSORS_FMT=text \
python3 ./debug_bf16.py

to dump the IR. The observation is that in Case 1 (cast_after_init = True), the backward-pass IR (corresponding to loss.backward() between the two mark_step calls) has an extra cast back to float32 around the second permute:

IR {
  %0 = bf16[] prim::Constant(), value=1
  %1 = bf16[22,66]{1,0} aten::expand(%0), size=(22, 66)
  %2 = bf16[22,44]{1,0} xla::device_data(), device=TPU:0
  %3 = bf16[44,22]{0,1} aten::permute(%2), dims=(1, 0)
  %4 = bf16[44,66]{1,0} aten::mm(%3, %1)
  %5 = f32[44,66]{1,0} xla::cast(%4), type=f32, dtype=Float, stype=BFloat16
  %6 = f32[66,44]{0,1} aten::permute(%5), dims=(1, 0)
  %7 = bf16[66,44]{0,1} xla::cast(%6), type=bf16, dtype=BFloat16, stype=Float, ROOT=0
}

whereas in Case 2 (cast_after_init = False), the backward-pass IR stays entirely in bfloat16, as expected:

IR {
  %0 = bf16[] prim::Constant(), value=1
  %1 = bf16[22,66]{1,0} aten::expand(%0), size=(22, 66)
  %2 = bf16[22,44]{1,0} xla::device_data(), device=TPU:0
  %3 = bf16[44,22]{0,1} aten::permute(%2), dims=(1, 0)
  %4 = bf16[44,66]{1,0} aten::mm(%3, %1)
  %5 = bf16[66,44]{0,1} aten::permute(%4), dims=(1, 0), ROOT=0
}

Expected behavior

No f32 ops should appear in the backward IR; it should stay in bf16, as in Case 2.

Environment

Additional context

reported by @ronghanghu and @hjm-aws

JackCaoG commented 2 years ago

So the cast comes from a copy_:

void XLATensor::copy_(XLATensorPtr& input, XLATensorPtr& src) {
  if (input->GetDevice() == src->GetDevice()) {
    torch::lazy::Value copy_value;
    if (input->dtype() == src->dtype()) {
      copy_value = src->GetIrValue();
    } else {
      // <- gdb stopped here
      copy_value = torch::lazy::MakeNode<Cast>(src->GetIrValue(),
                                               input->dtype(), src->dtype());
    }
    input->SetIrValue(MaybeExpand(copy_value, input->shape()));
  }
(gdb) p input->dtype()
$1 = c10::ScalarType::BFloat16
(gdb) p src->dtype()
$2 = c10::ScalarType::Float

The first cast (f32 -> bf16) comes from the forward graph:

IR {
......
  %9 = f32[66,44]{1,0} aten::uniform(%8, %7, %6), location=kaiming_uniform_@init.py:412
  %10 = bf16[66,44]{1,0} xla::cast(%9), location=convert@module.py:981, type=bf16, dtype=BFloat16, stype=Float, ROOT=1
......
}

The second cast (bf16 -> f32) is the unnecessary one we are talking about here. Tracing it up the call stack, I found that it comes from

Tensor to(const Tensor& self, ScalarType dtype, bool non_blocking, bool copy, c10::optional
  return to_impl(   // <- gdb stopped here
      self,
      dtype,
      nullopt,
      nullopt,
      nullopt,
      non_blocking,
      copy,
      optional_memory_format);
}

where dtype is f32:

(gdb) p dtype
$12 = c10::ScalarType::Float

Continuing up the stack, I eventually reach /pytorch/torch/csrc/autograd/engine.cpp and see

 741│     if (c10::typeMetaToScalarType(metadata.options().dtype()) !=
 742│         grad.scalar_type()) {
 743├>      grad = grad.to(c10::typeMetaToScalarType(metadata.options().dtype()));
 744│     }

grad.scalar_type() is bf16, but the metadata dtype is f32. Trying to figure out why.

JackCaoG commented 2 years ago

I think I have roughly figured out what happened. If you look at the forward graph for the cast_after_init case, it is

IR {
  %0 = bf16[] prim::Constant(), location=<module>@debug_bf16.py:13, value=1
  %1 = bf16[22,44]{1,0} aten::expand(%0), location=<module>@debug_bf16.py:13, size=(22, 44), ROOT=0
  %2 = s64[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=CPU:0
  %3 = s64[] prim::Constant(), location=kaiming_uniform_@init.py:412, value=214013
  %4 = s64[] aten::mul(%3, %2), location=kaiming_uniform_@init.py:412
  %5 = s64[] prim::Constant(), location=kaiming_uniform_@init.py:412, value=2531011
  %6 = s64[] aten::add(%5, %4), location=kaiming_uniform_@init.py:412
  %7 = f32[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=CPU:0
  %8 = f32[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=CPU:0
  %9 = f32[66,44]{1,0} aten::uniform(%8, %7, %6), location=kaiming_uniform_@init.py:412
  %10 = bf16[66,44]{1,0} xla::cast(%9), location=convert@module.py:981, type=bf16, dtype=BFloat16, stype=Float, ROOT=1
  %11 = bf16[44,66]{0,1} aten::permute(%10), location=forward@linear.py:114, dims=(1, 0)
  %12 = bf16[22,66]{1,0} aten::mm(%1, %11), location=forward@linear.py:114, ROOT=2
  %13 = bf16[] aten::sum(%12), location=<module>@debug_bf16.py:25, dimensions=(0, 1), keep_reduced_dimensions=0, dtype=15, ROOT=3
}

and for the case that passes the dtype during init, it is

IR {
  %0 = bf16[] prim::Constant(), location=<module>@debug_bf16.py:13, value=1
  %1 = bf16[22,44]{1,0} aten::expand(%0), location=<module>@debug_bf16.py:13, size=(22, 44), ROOT=0
  %2 = s64[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=CPU:0
  %3 = s64[] prim::Constant(), location=kaiming_uniform_@init.py:412, value=214013
  %4 = s64[] aten::mul(%3, %2), location=kaiming_uniform_@init.py:412
  %5 = s64[] prim::Constant(), location=kaiming_uniform_@init.py:412, value=2531011
  %6 = s64[] aten::add(%5, %4), location=kaiming_uniform_@init.py:412
  %7 = bf16[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=CPU:0
  %8 = bf16[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=CPU:0
  %9 = bf16[66,44]{1,0} aten::uniform(%8, %7, %6), location=kaiming_uniform_@init.py:412, ROOT=1
  %10 = bf16[44,66]{0,1} aten::permute(%9), location=forward@linear.py:114, dims=(1, 0)
  %11 = bf16[22,66]{1,0} aten::mm(%1, %10), location=forward@linear.py:114, ROOT=2
  %12 = bf16[] aten::sum(%11), location=<module>@debug_bf16.py:25, dimensions=(0, 1), keep_reduced_dimensions=0, dtype=15, ROOT=3
}

The difference is that the device data is bf16 from the beginning if the dtype is passed during init. If we call to(bf16) after init, the device data is uploaded as f32 and then cast to bf16. I think the cast in the backward was a result of the cast in the forward (grad is in bf16 though, I guess it looks it up and found its input was a f32?).

ronghanghu commented 2 years ago

I see. @JackCaoG Thanks for looking into it!

I think the cast in the backward was a result of the cast in the forward (grad is in bf16 though, I guess it looks it up and found its input was a f32?).

Yeah, this seems to be where the cast to f32 comes in. However, another weird thing is that when I additionally insert an xm.mark_step() before the forward pass and put the linear module construction under torch.no_grad(), the unnecessary cast in the backward pass still happens, even though the forward pass (traced by autograd) now starts directly from bfloat16 device data.

Specifically, I'm running the following to separate construction, forward, and backward:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
batchsize = 22
inputsize = 44
outputsize = 66

with torch.no_grad():
    # input
    x = torch.ones(batchsize, inputsize, dtype=torch.bfloat16, device=device)

    # module
    linear = torch.nn.Linear(inputsize, outputsize, device=device, bias=False)
    linear = linear.to(torch.bfloat16)

xm.mark_step()
y = linear(x)
loss = y.sum()

xm.mark_step()
loss.backward()
xm.mark_step()

which gives 3 IR graphs.

The first one is the nn.Linear module construction and cast to bf16, which is under torch.no_grad and so shouldn't be traced by autograd:

IR {
  %0 = bf16[] prim::Constant(), value=1
  %1 = bf16[22,44]{1,0} aten::expand(%0), size=(22, 44), ROOT=0
  %2 = s64[] xla::device_data(), device=TPU:0
  %3 = s64[] prim::Constant(), value=214013
  %4 = s64[] aten::mul(%3, %2)
  %5 = s64[] prim::Constant(), value=2531011
  %6 = s64[] aten::add(%5, %4)
  %7 = f32[] xla::device_data(), device=TPU:0
  %8 = f32[] xla::device_data(), device=TPU:0
  %9 = f32[66,44]{1,0} aten::uniform(%8, %7, %6)
  %10 = bf16[66,44]{1,0} xla::cast(%9), type=bf16, dtype=BFloat16, stype=Float, ROOT=1
}

The second one is the forward pass, which now directly starts with bfloat16 device data:

IR {
  %0 = bf16[66,44]{0,1} xla::device_data(), device=TPU:0
  %1 = bf16[44,66]{1,0} aten::permute(%0), dims=(1, 0)
  %2 = bf16[22,44]{1,0} xla::device_data(), device=TPU:0
  %3 = bf16[22,66]{1,0} aten::mm(%2, %1), ROOT=0
  %4 = bf16[] aten::sum(%3), dimensions=(0, 1), keep_reduced_dimensions=0, dtype=15, ROOT=1
}

And the third one is the backward pass, which unfortunately still involves a cast to f32:

IR {
  %0 = bf16[] prim::Constant(), value=1
  %1 = bf16[22,66]{1,0} aten::expand(%0), size=(22, 66)
  %2 = bf16[22,44]{1,0} xla::device_data(), device=TPU:0
  %3 = bf16[44,22]{0,1} aten::permute(%2), dims=(1, 0)
  %4 = bf16[44,66]{1,0} aten::mm(%3, %1)
  %5 = f32[44,66]{1,0} xla::cast(%4), type=f32, dtype=Float, stype=BFloat16
  %6 = f32[66,44]{0,1} aten::permute(%5), dims=(1, 0)
  %7 = bf16[66,44]{0,1} xla::cast(%6), type=bf16, dtype=BFloat16, stype=Float, ROOT=0
}

It seems that even in this case, autograd still remembers the old f32 dtype somewhere, although we have torch.no_grad() and an extra xm.mark_step() before starting the forward pass.

ronghanghu commented 2 years ago

And this cast to f32 in the backward pass happens not only in the first iteration but also in subsequent iterations if we run forward and backward multiple times, even though the forward pass now starts with bf16 device data for the linear.weight parameter in %0 = bf16[66,44]{0,1} xla::device_data(), device=TPU:0

So autograd must still be remembering the old f32 dtype somewhere to keep this behavior in subsequent iterations that don't involve a cast in the forward pass (not sure where it stores this f32 dtype).

JackCaoG commented 2 years ago

The PyTorch autograd engine is a bit beyond my knowledge. @bdhirsh, do you know what metadata is in https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/engine.cpp#L743 ?

It comes from

  for (const auto i : c10::irange(grads.size())) {
    const auto& edge = edges[i];
    if (!edge.is_valid())
      continue;

    const auto& metadata = edge.function->input_metadata(edge.input_nr);

but I don't know what this edge means.
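For reference (CPU-only, not part of the XLA repro): on the Python side, each autograd node exposes its outgoing edges as `grad_fn.next_functions`, a tuple of `(node, input_nr)` pairs that mirrors `edge.function` / `edge.input_nr` in the C++ loop above. A minimal sketch on a hypothetical toy graph:

```python
import torch

# Hypothetical toy graph: leaf -> clone -> sum
leaf = torch.ones(3, requires_grad=True)  # leaf tensor (no grad_fn)
out = leaf.clone().sum()                  # out.grad_fn is the sum's backward node


def describe_edges(node):
    """Return (node type name, input_nr) for every outgoing edge of `node`."""
    return [
        (type(next_node).__name__ if next_node is not None else None, input_nr)
        for next_node, input_nr in node.next_functions
    ]


sum_node = out.grad_fn
clone_node = sum_node.next_functions[0][0]       # edge 0 of sum -> clone's backward node
acc_node, input_nr = clone_node.next_functions[0]  # edge 0 of clone -> leaf's AccumulateGrad

print(describe_edges(sum_node))    # one edge, pointing at the clone's backward node
print(describe_edges(clone_node))  # one edge, pointing at the leaf's AccumulateGrad node
print(acc_node.variable is leaf)   # True: AccumulateGrad holds a reference to the leaf
```

Roughly speaking, `input_metadata(edge.input_nr)` in `engine.cpp` is the dtype/shape/device record kept on whichever node such an edge points to.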

bdhirsh commented 2 years ago

hmm @soulitzer, do you know what the purpose of that line in autograd (here) is? Is the idea that if we downcast a tensor in the forward, then we need to make sure to upcast back to its original dtype in the backward? (e.g. float_tensor.to(dtype=torch.bfloat16)).

I can also take more of a look when I'm back next Tuesday.

hjm-aws commented 2 years ago

FYI it seems this logic was introduced in https://github.com/pytorch/pytorch/commit/88e4cee3e70aac95dd2c18b898808ce3426cb3c9#diff-c66dfeac2a2da1867233047ec413c7e625644c672d7b38b8ec982f5605923c64.

soulitzer commented 2 years ago

Is the idea that if we downcast a tensor in the forward, then we need to make sure to upcast back to its original dtype in the backward? (e.g. float_tensor.to(dtype=torch.bfloat16)).

Ah, for something like .to there should already be an autograd formula that explicitly handles that (_to_copy_backward). Though I guess it is somewhat redundant given the logic in https://github.com/pytorch/pytorch/commit/88e4cee3e70aac95dd2c18b898808ce3426cb3c9#diff-c66dfeac2a2da1867233047ec413c7e625644c672d7b38b8ec982f5605923c64, which relieves backward-formula writers of having to cast the gradient to the dtype of each input; that can be tricky, for example, when the inputs have different dtypes.
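A small CPU-only sketch of the two behaviours discussed here, assuming a recent PyTorch build: when `.to(torch.bfloat16)` is recorded by autograd, its backward hands the gradient back in the input's float32 dtype; under `torch.no_grad()` the same call is not recorded at all, so there is no backward node that could cast anything.

```python
import torch

x = torch.ones(4, requires_grad=True)  # float32 leaf

# Case A: .to() recorded by autograd; its backward returns the grad in x's dtype
y = x.to(torch.bfloat16)
y.sum().backward()
print(y.dtype, x.grad.dtype)  # torch.bfloat16 torch.float32

# Case B: .to() under no_grad is not recorded, so it has no backward at all
with torch.no_grad():
    z = x.to(torch.bfloat16)
print(z.requires_grad, z.grad_fn)  # False None
```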

hjm-aws commented 2 years ago

So, is the solution to disable autograd on the .to(bf16) call, so that the .to call no longer triggers this automatic cast back? Never mind, I found that the relevant .to(bf16) calls were already decorated with @torch.no_grad().

JackCaoG commented 2 years ago

I guess the issue here is that

  %6 = f32[66,44]{0,1} aten::permute(%5), dims=(1, 0)

This seems to be why the casting is happening, I am guessing this is the backward of some op?

hjm-aws commented 2 years ago

Yes, it's the backward of F.linear().

ronghanghu commented 2 years ago

This seems to be why the casting is happening, I am guessing this is the backward of some op?

Yeah, I think this aten::permute(%5) is the autograd generated backward of the transpose .t() in weight.t() in the forward pass of at::linear (used by nn.Linear) from https://github.com/pytorch/pytorch/blob/23bdb570cf05f0cefdacdda5cbf73f58a2e574f4/aten/src/ATen/native/Linear.cpp#L44 (where weight is a parameter in nn.Linear that has been cast to bfloat16 before the forward pass in our case).

JackCaoG commented 2 years ago

@ronghanghu regarding your question above: mark_step only affects PyTorch/XLA's view of how a tensor is stored; it does not affect the autograd engine, which is a layer above PyTorch/XLA. In other words, mark_step does not clear the autograd state, so it makes sense that your approach in https://github.com/pytorch/xla/issues/3718#issuecomment-1185083739 doesn't work.

It seems to me that the cast to f32 is actually by design, and I get why it is needed. Is it blocking you?

hjm-aws commented 2 years ago

@JackCaoG Sorry Jack, can you explain why the cast is needed?

So what happens in the forward is the following in pseudo code:

t = tensor(f32)
with no_grad:
  t = t.to(bf16)
l = F.linear(t)

So I don't see why the .to(bf16) should be affected by auto_grad in the backward pass.

The two extra cast ops create two extra tensors. If the tensors are large or numerous, this will create memory pressure on the device.

JackCaoG commented 2 years ago

So what happens in the forward is the following in pseudo code:

t = tensor(f32)
with no_grad:
  t = t.to(bf16)
l = F.linear(t)

So I don't see why the .to(bf16) should be affected by auto_grad.

Ah OK, I missed the no_grad part.

@soulitzer From what I can tell, the cast to f32 in the backward happens because the tensor was originally f32, and the autograd engine somehow remembers this. Given that t.to(bf16) happens inside a no_grad block, the backward of to should not run. Do you know why the cast to f32 still happens?

ymwangg commented 2 years ago

FYI, XLA's algebraic optimizer can remove this round trip to f32 during lowering, so it shouldn't affect performance. Update: I only verified this on GPU, and this may become an issue for other backends, as pointed out by @hjm-aws.

Original:

ENTRY SyncTensorsGraph.13 {
  p0.6 = bf16[22,44]{1,0} parameter(0)
  transpose.7 = bf16[44,22]{0,1} transpose(p0.6), dimensions={1,0}
  constant.1 = bf16[] constant(1)
  reshape.2 = bf16[1,1]{1,0} reshape(constant.1)
  broadcast.3 = bf16[1,1]{1,0} broadcast(reshape.2), dimensions={0,1}
  reshape.4 = bf16[] reshape(broadcast.3)
  broadcast.5 = bf16[22,66]{1,0} broadcast(reshape.4), dimensions={}
  dot.8 = bf16[44,66]{1,0} dot(transpose.7, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  convert.9 = f32[44,66]{1,0} convert(dot.8)
  transpose.10 = f32[66,44]{0,1} transpose(convert.9), dimensions={1,0}
  convert.11 = bf16[66,44]{0,1} convert(transpose.10)
  ROOT tuple.12 = (bf16[66,44]{0,1}) tuple(convert.11)
}

After optimization:

ENTRY SyncTensorsGraph.13 {
  constant.2 = bf16[] constant(1)
  broadcast.5 = bf16[22,66]{1,0} broadcast(constant.2), dimensions={}
  p0.6 = bf16[22,44]{1,0} parameter(0)
  dot.1 = bf16[66,44]{0,1} dot(broadcast.5, p0.6), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  ROOT tuple.12 = (bf16[66,44]{0,1}) tuple(dot.1)
}

soulitzer commented 2 years ago

Not sure how the code is still working with no_grad, actually, because to would return a tensor that doesn't require grad.

The cast can happen implicitly during backward even without .to, however. Here is a more concrete example of how autograd remembers the dtype from the forward:

import torch

a = torch.tensor(1., requires_grad=True, dtype=torch.float32).clone()
b = torch.tensor(1., requires_grad=True, dtype=torch.float64)

(a * b).backward()

1. During forward, after clone, the autograd kernel saves the dtype (float32) of clone's output onto the graph as part of the "input metadata".
2. The output of a * b is float64 due to binary type promotion.
3. Perform backward: the first node is mul, and we run mul_backward in float64.
4. After every computation during backward, the engine checks the input metadata of the next node in the graph (clone, in this case) and tries to match it. Here we cast back to float32.
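The steps above can be checked on CPU with a slightly extended version of the snippet. `leaf` and the `retain_grad()` call are additions for inspection only. Which component actually performs the float64 -> float32 cast (the op's own backward formula or the engine's metadata check in step 4) may vary by op and version, so the sketch only looks at the resulting dtypes:

```python
import torch

leaf = torch.tensor(1.0, requires_grad=True)  # float32 leaf
a = leaf.clone()                               # non-leaf; its recorded dtype is float32
a.retain_grad()                                # keep a.grad so its dtype can be inspected
b = torch.tensor(1.0, requires_grad=True, dtype=torch.float64)

out = a * b       # float64 via binary type promotion
out.backward()

print(out.dtype)        # torch.float64
print(a.grad.dtype)     # torch.float32: the grad was cast back to a's dtype
print(b.grad.dtype)     # torch.float64
print(leaf.grad.dtype)  # torch.float32
```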


I'm not actually sure how relevant the above is to the issue at hand, though, since running the original repro on CPU with added prints reveals that the engine doesn't actually enter the code path from (4) above.

ronghanghu commented 2 years ago

mark_step only affects PyTorch/XLA's view of how a tensor is stored; it does not affect the autograd engine, which is a layer above PyTorch/XLA.

@JackCaoG Yeah, I was aware of this. I was just curious whether there is a way to cast an nn.Module to bfloat16 so that it looks exactly as if it had been constructed in bfloat16. So far it seems that putting the cast linear = linear.to(torch.bfloat16) under no_grad is not enough to eliminate this difference in autograd. (I'm also looking into this issue because the FSDP class has a compute_dtype that can be bfloat16 and involves many such casts.)


FYI, XLA's algebraic optimizer can remove this round trip to f32 during lowering, so it shouldn't affect performance.

@ymwangg Thanks, this is great to know :)


Not sure how the code is still working with no_grad, actually, because to would return a tensor that doesn't require grad.

@soulitzer I think this is because nn.Module either

1) explicitly turned the requires_grad flag back on in https://github.com/pytorch/pytorch/blob/f3f8d96ea69134770198dec485921f9dba45b5ed/torch/nn/modules/module.py#L665

or

2) directly put the casted new tensor into its ".data" in https://github.com/pytorch/pytorch/blob/f3f8d96ea69134770198dec485921f9dba45b5ed/torch/nn/modules/module.py#L660

so if a parameter originally has requires_grad == True, then it will remain requires_grad == True after linear = linear.to(torch.bfloat16)

I think this 2nd behavior is probably why we see this issue.

ronghanghu commented 2 years ago

@hjm-aws @JackCaoG @soulitzer OK, I think I figured out the underlying cause of the issue above.

It's because nn.Module by default casts a parameter to another dtype by directly assigning the new cast tensor to its .data (in https://github.com/pytorch/pytorch/blob/f3f8d96ea69134770198dec485921f9dba45b5ed/torch/nn/modules/module.py#L660), such as

# what's happening in `linear = linear.to(torch.bfloat16)`:
p.data = p.data.to(torch.bfloat16)

Unfortunately, this leaves the parameter's data cast to bfloat16 but keeps the parameter's metadata in float32. (So now we have a dtype discrepancy, and hence the cast to float32 in the backward IR.)

The solution to this example above is to use torch.__future__.set_overwrite_module_params_on_conversion(True). For example, with the codebase below, we now get an elegant backward pass without any cast back to fp32:

import torch
import torch_xla.core.xla_model as xm

torch.__future__.set_overwrite_module_params_on_conversion(True)

device = xm.xla_device()
batchsize = 22
inputsize = 44
outputsize = 66

# input
x = torch.ones(batchsize, inputsize, dtype=torch.bfloat16, device=device)

# module
linear = torch.nn.Linear(inputsize, outputsize, device=device, bias=False)
linear = linear.to(torch.bfloat16)

xm.mark_step()
y = linear(x)
loss = y.sum()

xm.mark_step()
loss.backward()
xm.mark_step()

This has a clean backward pass IR (no cast back to fp32 involved):

IR {
  %0 = bf16[] prim::Constant(), value=1
  %1 = bf16[22,66]{1,0} aten::expand(%0), size=(22, 66)
  %2 = bf16[22,44]{1,0} xla::device_data(), device=TPU:0
  %3 = bf16[44,22]{0,1} aten::permute(%2), dims=(1, 0)
  %4 = bf16[44,66]{1,0} aten::mm(%3, %1)
  %5 = bf16[66,44]{0,1} aten::permute(%4), dims=(1, 0), ROOT=0
}
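The two conversion paths can also be compared on CPU without an XLA device. A minimal sketch, assuming only the public `torch.__future__` getter/setter; `convert_linear` is a hypothetical helper. It checks Parameter identity rather than the backward IR, so it illustrates the mechanism without reproducing the XLA cast itself:

```python
import torch


def convert_linear(overwrite: bool):
    """Hypothetical helper: cast a CPU nn.Linear to bf16 and report how its weight was replaced."""
    prev = torch.__future__.get_overwrite_module_params_on_conversion()
    torch.__future__.set_overwrite_module_params_on_conversion(overwrite)
    try:
        linear = torch.nn.Linear(44, 66, bias=False)
        before = linear.weight
        linear = linear.to(torch.bfloat16)
        after = linear.weight
        return {
            "same_object": after is before,
            "dtype": after.dtype,
            "is_leaf": after.is_leaf,
            "requires_grad": after.requires_grad,
        }
    finally:
        # The flag is process-global, so restore whatever was set before.
        torch.__future__.set_overwrite_module_params_on_conversion(prev)


print(convert_linear(False))  # default: same Parameter object, only .data swapped to bf16
print(convert_linear(True))   # flag on: a brand-new bf16 Parameter is registered
```

With the default setting the same Parameter object survives and only its `.data` is replaced, which is the path described above. With the flag on, `nn.Module` registers a new leaf Parameter, so anything holding a reference to the old one (an optimizer, for instance) should be built after the cast.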

@hjm-aws running torch.__future__.set_overwrite_module_params_on_conversion(True) could fix the example above but not the FSDP case with compute_dtype=torch.bfloat16. I think I need to roll out a similar solution in FSDP for such cases. But since @ymwangg mentioned that the XLA can eliminate this cast, I guess this is not blocking you :)


@soulitzer I think the root cause is that p.data = p.data.to(torch.bfloat16) in PyTorch core doesn't update p's metadata to have bfloat16 dtype, leaving a discrepancy between the metadata dtype and the actual dtype.

Maybe assigning to its .data should change a tensor's metadata such as dtype, shape, and device? (But that would be a backward-incompatible change in PyTorch core so probably not a good idea either.)

Also, is there a way for a user to explicitly update a tensor p's meta-data (without making another torch.Tensor object, so that we can keep the id of p)?

hjm-aws commented 2 years ago

@ronghanghu Thanks for the investigation! It's really helpful.

As for the optimization passes, I need to check with our compiler team that they can be adopted. I will update the thread either way.

soulitzer commented 2 years ago

I think the root cause is that p.data = p.data.to(torch.bfloat16) in PyTorch core doesn't update p's metadata to have bfloat16 dtype, leaving a discrepancy between the metadata dtype and the actual dtype.

@ronghanghu Good catch!

I'd point out that metadata can only go out-of-date when p is not a leaf node. When p is a leaf node, the following should take care of updating the metadata (input metadata of its grad accumulator): https://github.com/pytorch/pytorch/blob/516f3198d65f6932299d41ddbb98c26de5f0a367/torch/csrc/autograd/variable.cpp#L484-L498

So another way to fix the issue is just to ensure that your parameters are actually leaf nodes before doing .to(dtype) (or figure out why your parameters aren't leaf nodes):

for p in linear.parameters():
    p.detach_().requires_grad_(True)
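
What the loop does to a single non-leaf tensor can be sketched on CPU; `base` and `p` are hypothetical stand-ins for a parameter that picked up a `grad_fn`:

```python
import torch

base = torch.ones(3, requires_grad=True)
p = base * 2                              # hypothetical non-leaf "parameter" (has a grad_fn)
print(p.is_leaf, p.grad_fn is not None)   # False True

p.detach_().requires_grad_(True)          # the re-rooting step suggested above
print(p.is_leaf, p.grad_fn)               # True None
print(p.requires_grad)                    # True
```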

ronghanghu commented 2 years ago

I'd point out that metadata can only go out-of-date when p is not a leaf node. When p is a leaf node, the following should take care of updating the metadata (input metadata of its grad accumulator): https://github.com/pytorch/pytorch/blob/516f3198d65f6932299d41ddbb98c26de5f0a367/torch/csrc/autograd/variable.cpp#L484-L498

So another way to fix the issue is just to ensure that your parameters are actually leaf nodes before doing .to(dtype) (or figure out why your parameters aren't leaf nodes):

@soulitzer Thanks for the explanation! However, I checked the example above, and found that its parameter (linear.weight) is already a leaf node. Actually, it is a leaf node both before and after the cast linear = linear.to(torch.bfloat16) regardless of whether I use torch.__future__.set_overwrite_module_params_on_conversion(True) in https://github.com/pytorch/xla/issues/3718#issuecomment-1190696341.


This issue is also reproducible without using nn.Linear or any nn.Module, but only using torch.Tensor. For example, if I first construct a weight tensor in float32 and then cast it to bfloat16, there will be an unexpected cast to float32 in the backward pass IR:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

cast_after_init = True
# cast_after_init = False

# weight
if cast_after_init:
    weight = torch.ones(66, 44, device=device, requires_grad=True)
    assert weight.is_leaf
    with torch.no_grad():
        weight.data = weight.data.to(torch.bfloat16)
else:
    weight = torch.ones(66, 44, device=device, requires_grad=True, dtype=torch.bfloat16)
xm.mark_step()

# forward pass
x = torch.ones(22, 44, dtype=torch.bfloat16, device=device)
y = torch.nn.functional.linear(x, weight)
loss = y.sum()
xm.mark_step()

# backward pass
loss.backward()
print(f"cast_after_init: {cast_after_init}")
print("backward:", torch_xla._XLAC._get_xla_tensors_text([weight.grad]))
xm.mark_step()

which prints

cast_after_init: True
backward: IR {
  %0 = bf16[] prim::Constant(), value=1
  %1 = bf16[22,66]{1,0} aten::expand(%0), size=(22, 66)
  %2 = bf16[22,44]{1,0} xla::device_data(), device=TPU:0
  %3 = bf16[44,22]{0,1} aten::permute(%2), dims=(1, 0)
  %4 = bf16[44,66]{1,0} aten::mm(%3, %1)
  %5 = f32[44,66]{1,0} xla::cast(%4), type=f32, dtype=Float, stype=BFloat16
  %6 = f32[66,44]{0,1} aten::permute(%5), dims=(1, 0)
  %7 = bf16[66,44]{0,1} xla::cast(%6), type=bf16, dtype=BFloat16, stype=Float, ROOT=0
}

On the contrary, if I directly construct a weight tensor in bfloat16 (i.e. set cast_after_init=False in the script above), then the cast to float32 doesn't happen in the backward pass, and it prints

cast_after_init: False
backward: IR {
  %0 = bf16[] prim::Constant(), value=1
  %1 = bf16[22,66]{1,0} aten::expand(%0), size=(22, 66)
  %2 = bf16[22,44]{1,0} xla::device_data(), device=TPU:0
  %3 = bf16[44,22]{0,1} aten::permute(%2), dims=(1, 0)
  %4 = bf16[44,66]{1,0} aten::mm(%3, %1)
  %5 = bf16[66,44]{0,1} aten::permute(%4), dims=(1, 0), ROOT=0
}

Since the only difference between these two cases above is whether weight is cast to or directly constructed in bfloat16, I suspect the cast weight.data = weight.data.to(torch.bfloat16) in the first case of cast_after_init=True didn't correctly update the metadata of the weight tensor to bfloat16 dtype, even though weight is a leaf node here.

soulitzer commented 2 years ago

Yeah, I'm not too sure what is going on; likely XLA-specific? I reproduced this on Colab (1.11), but adding the following check shows that the input_metadata of the weight is bfloat16, as expected, after setting .data.

import os
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch.utils.cpp_extension import load_inline
source = """
  #include "ATen/core/Tensor.h"
  int get_metadata_dtype(at::Tensor& t) {
    auto tmp = t.clone();
    return static_cast<int>(c10::typeMetaToScalarType(tmp.grad_fn()->next_edge(0).function->input_metadata(0).dtype()));
  }
"""

inline = load_inline(
    name="inline_extension",
    cpp_sources=[source],
    functions=["get_metadata_dtype"],
    extra_include_paths=['/usr/local/lib/python3.7/dist-packages/torch/include'],
    build_directory="/content/torch_cpp_extension"
)

device = xm.xla_device()

cast_after_init = True
# cast_after_init = False

# weight
if cast_after_init:
    weight = torch.ones(66, 44, device=device, requires_grad=True)
    print(inline.get_metadata_dtype(weight)) # 6 = float

    assert weight.is_leaf
    with torch.no_grad():
        weight.data = weight.data.to(torch.bfloat16)

    print(inline.get_metadata_dtype(weight)) # 15 = bfloat16 (if non-leaf, this is also 6)

else:
    weight = torch.ones(66, 44, device=device, requires_grad=True, dtype=torch.bfloat16)

# forward pass
x = torch.ones(22, 44, dtype=torch.bfloat16, device=device)
y = torch.nn.functional.linear(x, weight)
loss = y.sum()
xm.mark_step()

# backward pass
loss.backward()
print(f"cast_after_init: {cast_after_init}")
print("backward:", torch_xla._XLAC._get_xla_tensors_text([weight.grad]))
xm.mark_step()

hjm-aws commented 2 years ago

But since @ymwangg mentioned that the XLA can eliminate this cast, I guess this is not blocking you :)

@ronghanghu Ronghang, I can confirm that our compiler can get rid of the double cast.

baoleai commented 1 year ago

This also occurs in AMP scenarios, even after setting torch.__future__.set_overwrite_module_params_on_conversion(True), @ronghanghu. For example:

import torch
import torch_xla.core.xla_model as xm

from torch_xla.amp import autocast

torch.__future__.set_overwrite_module_params_on_conversion(True)

device = xm.xla_device()
batchsize = 22
inputsize = 44
outputsize = 66

# input
x = torch.ones(batchsize, inputsize, dtype=torch.float, device=device)

# module
linear = torch.nn.Linear(inputsize, outputsize, device=device, bias=False)

xm.mark_step()
with autocast(device, dtype=torch.bfloat16):
  y = linear(x)
  loss = y.sum()
xm.mark_step()
loss.backward()
xm.mark_step()

With this script, the backward IR is:

IR {
  %0 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %1 = f32[] aten::expand(%0), xla_shape=f32[], size=(), dynamic_dims=()
  %2 = f32[22,66]{1,0} aten::expand(%1), xla_shape=f32[22,66]{1,0}, size=(22, 66)
  %3 = bf16[22,66]{1,0} xla::cast(%2), xla_shape=bf16[22,66]{1,0}, type=bf16, dtype=BFloat16, stype=Float
  %4 = bf16[22,44]{1,0} xla::device_data(), xla_shape=bf16[22,44]{1,0}, device=GPU:0
  %5 = bf16[44,22]{0,1} aten::permute(%4), xla_shape=bf16[44,22]{0,1}, dims=(1, 0)
  %6 = bf16[44,66]{1,0} aten::mm(%5, %3), xla_shape=bf16[44,66]{1,0}
  %7 = bf16[66,44]{0,1} aten::permute(%6), xla_shape=bf16[66,44]{0,1}, dims=(1, 0)
  %8 = f32[66,44]{0,1} xla::cast(%7), xla_shape=f32[66,44]{0,1}, type=f32, dtype=Float, stype=BFloat16, ROOT=0
}

There is an extra cast %8 in the IR.