pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

[Fori Loop] Inconsistent Shape Behavior #7665

Open huzama opened 2 months ago

huzama commented 2 months ago

🐛 Bug Report

Description

Running the following code results in the error:

Status: INVALID_ARGUMENT: The parameter of condition and body, the result of the body, and init must all have the same shape; got Condition: (in: (s64[], s64[], s64[1], s64[1], s64[1])) -> pred[]; body: (in: (s64[], s64[1], s64[1], s64[1], s64[1])) -> (s64[], s64[], s64[1], s64[1], s64[1]); init: (s64[], s64[], s64[1], s64[1], s64[1])..

However, changing the operation inside body_fn seems to resolve the issue.
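To make the mismatch in the error message easier to spot, here is a minimal sketch that compares the shape tuples position by position. It uses only the Python standard library; the shape strings are copied from the error above, and the helper names `parse_shapes` and `mismatches` are hypothetical, introduced purely for illustration (they are not part of torch_xla).

```python
import re

# Shape tuples copied verbatim from the error message above.
SIGNATURES = {
    "cond_in": "(s64[], s64[], s64[1], s64[1], s64[1])",
    "body_in": "(s64[], s64[1], s64[1], s64[1], s64[1])",
    "body_out": "(s64[], s64[], s64[1], s64[1], s64[1])",
    "init": "(s64[], s64[], s64[1], s64[1], s64[1])",
}


def parse_shapes(signature):
    """Split a tuple such as '(s64[], s64[1])' into ['s64[]', 's64[1]']."""
    return re.findall(r"[a-z]+\d*\[[\d,]*\]", signature)


def mismatches(reference, other):
    """Return (index, reference_shape, other_shape) for each differing position."""
    ref, oth = parse_shapes(reference), parse_shapes(other)
    return [(i, r, o) for i, (r, o) in enumerate(zip(ref, oth)) if r != o]


if __name__ == "__main__":
    # Compare every signature against init, which all of them must match.
    for name in ("cond_in", "body_in", "body_out"):
        print(name, mismatches(SIGNATURES["init"], SIGNATURES[name]))
```

With the strings above, only the body's input differs from init: position 1 (zero-based) is `s64[1]` in the body but `s64[]` everywhere else.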

Code to Reproduce

import torch
import torch_xla
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm

class TPUComputation:
    def __init__(self):
        self.device = xm.xla_device()
        self.init_x = torch.tensor([1], device=self.device)
        self.init_y = torch.tensor([1], device=self.device)
        self.init_z = torch.tensor([1], device=self.device)
        self.iteri = torch.tensor(10, device=self.device)
        self.quantity = torch.tensor(0, device=self.device)

    def cond_fn(self, iteri, x, y, z, q=None):
        return iteri > 0

    def body_fn(self, iteri, x, y, z, q=None):
        return iteri - 1, x.clone(), y.add(1), z + self.quantity  # Problematic line

    def compute(self):
        result = while_loop(
            self.cond_fn,
            self.body_fn,
            (self.iteri, self.init_x, self.init_y, self.init_z),
        )
        return result

if __name__ == "__main__":
    computation = TPUComputation()
    result = computation.compute()
    print(result)

Operations that Work

The following return expressions in body_fn do not produce the error:

  1. iteri - 1, x.clone(), y.add(1), z + 1
  2. iteri - 1, x.clone(), y.add(1), z + self.quantity when self.quantity = torch.tensor([0], device=self.device)

Environment

JackCaoG commented 2 months ago

@ManfeiBai can you take a look?

ManfeiBai commented 1 month ago

Hi @huzama, thanks for pointing this out.

Yes, I have reproduced the same error locally by following your instructions:

I tried again with a few more attempts:

Based on the XLA computations for body and cond in the log linked above, we notice that XLA computation generation misses/ignores tensor(0) and tensor(1) and tries to create a new constant locally to use instead.


As for how to print the XLA computation: I am not sure whether anyone needs this, but I am posting it here and will add it to the while_loop documentation later. Add the following code after line1 and line2:

# after line1
  body_hlo_print = xb.get_computation_hlo(body_computation)
  print("body computation: !!!!!!!!!")
  print(body_hlo_print)

# after line2
  cond_hlo_print = xb.get_computation_hlo(cond_computation)
  print("cond computation: !!!!!!!!!")
  print(cond_hlo_print)