pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

While operator test generates condition input as a parameter instead of a constant #7986

Open aws-rhsoln opened 1 week ago

aws-rhsoln commented 1 week ago

When running the while operator test: https://github.com/pytorch/xla/blob/master/test/test_while_loop.py#L28

I see an HLO that looks like the following (this is the un-optimized graph):

while_loop.40 {
  p0.41 = s64[] parameter(0)
  p1.42 = s32[] parameter(1)
  tuple.43 = (s64[], s32[]) tuple(p0.41, p1.42)
  ROOT while.44 = (s64[], s32[]) while(tuple.43), condition=PyLoweringContext.5.35, body=PyLoweringContext.12.25
}

ENTRY SyncTensorsGraph.49 {
  p0.3 = s64[] parameter(0), sharding={replicated}
  constant.2 = s64[] constant(1)
  constant.1 = s64[] constant(1)
  multiply.4 = s64[] multiply(constant.2, constant.1)
  subtract.5 = s64[] subtract(p0.3, multiply.4)
  p1.8 = s32[] parameter(1), sharding={replicated}
  constant.7 = s32[] constant(1)
  constant.6 = s32[] constant(1)
  multiply.9 = s32[] multiply(constant.7, constant.6)
  add.10 = s32[] add(p1.8, multiply.9)
  constant.11 = s64[] constant(0)
  compare.12 = pred[] compare(p0.3, constant.11), direction=GT
  call.45 = (s64[], s32[]) call(p0.3, p1.8), to_apply=while_loop.40
  get-tuple-element.46 = s64[] get-tuple-element(call.45), index=0
  get-tuple-element.47 = s32[] get-tuple-element(call.45), index=1
  ROOT tuple.48 = (s64[], s32[], pred[], s64[], s32[]) tuple(subtract.5, add.10, compare.12, get-tuple-element.46, get-tuple-element.47)
} // SyncTensorsGraph.49

In this case the input tuple contains a parameter. This means the XLA compiler won't know the trip count unless it evaluates the parameter during compilation. Shouldn't this be a constant?
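
For context, here is a minimal sketch of how that parameter arises (assuming the usual lazy-tensor lowering): iteri in the linked test is created as a device tensor, so it is traced as device data and shows up as a parameter, while Python scalars folded into the trace become constant ops.

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    # Device data: lowered as `s64[] parameter(...)` in the HLO above, so
    # the compiler cannot see the value 10 at compile time.
    iteri = torch.tensor(10, device=device)

    # A Python scalar folded into the trace becomes a constant op instead,
    # e.g. the multiply(constant(1), constant(1)) feeding subtract.5.
    y = iteri - 1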

I have tried other ways to make the condition input a constant; however, that input gets optimized away and I end up with an HLO like the following:

cond_hlo:

%PyLoweringContext.10 (p0.11: f32[2,2], UnusedArgumentsPlaceholder.15: f32[2,2]) -> pred[] {
  %constant.13 = s64[] constant(0), metadata={op_type="prim__Constant" op_name="prim__Constant" source_file="/home/ubuntu/while_loop.py" source_line=80}
  %p0.11 = f32[2,2]{1,0} parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/home/ubuntu/while_loop.py" source_line=80}
  %call.12 = s64[] call(f32[2,2]{1,0} %p0.11), to_apply=%return_constReturnConst.7, metadata={op_type="xla___op_return_constReturnConst" op_name="xla___op_return_constReturnConst" source_file="/home/ubuntu/pt24/lib/python3.10/site-packages/torch_xla/core/xla_op_registry.py" source_line=44}
  ROOT %compare.14 = pred[] compare(s64[] %constant.13, s64[] %call.12), direction=LT, metadata={op_type="aten__lt" op_name="aten__lt" source_file="/home/ubuntu/while_loop.py" source_line=75}
  %UnusedArgumentsPlaceholder.15 = f32[2,2]{1,0} parameter(1)
}

ENTRY %PyLoweringContext.12.17 (in.1: (f32[2,2], f32[2,2])) -> pred[] {
  %in.1 = (f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
  %get-tuple-element.2 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=0
  %get-tuple-element.3 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}, f32[2,2]{1,0}) %in.1), index=1
  ROOT %call.16 = pred[] call(f32[2,2]{1,0} %get-tuple-element.2, f32[2,2]{1,0} %get-tuple-element.3), to_apply=%PyLoweringContext.10
}

Questions:

  1. Should the condition input be a constant?
  2. I see some constraints on the while op. Is there a plan to make it work with different kinds of input values?
JackCaoG commented 1 week ago

@ManfeiBai @tengyifei

tengyifei commented 1 week ago

This means the XLA compiler won't know the trip count unless it evaluates the parameter during compilation.

Correct. I'm wondering whether you have a specific motivation for teaching XLA the trip count. Even if one input to the While op is a constant, the cond computation still has to compare it with another constant to determine whether to break out of the loop.
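
A rough sketch using the cond_fn/body_fn shape from the linked test (this is an illustration, not the exact lowering): the exit decision lives in cond_fn, so a constant loop input only helps if the compare and the per-trip update can also be folded.

    import torch

    def cond_fn(iteri, x):
      # For XLA to derive a trip count statically, both iteri and the
      # literal 0 here would need to be visible as compile-time constants.
      return iteri > 0

    def body_fn(iteri, x):
      # ...and the update to iteri must be statically analyzable as well.
      return iteri - 1, torch.add(x, 1)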

I have tried other ways to make the condition input a constant; however, that input gets optimized away

Could you share how you produced the HLO in the second section? Thanks.

aws-rhsoln commented 1 week ago

I want the ability to unroll the while loop, which requires XLA to know the trip count.

I generated the constant using xla_builder and passed it as input to _xla_while_loop_wrapper(cond_fn, body_fn, (iteri, init_val), ()). Here iteri is a constant.

ManfeiBai commented 1 week ago

I want the ability to unroll the while loop, which requires XLA to know the trip count.

I generated the constant using xla_builder and passed it as input to _xla_while_loop_wrapper(cond_fn, body_fn, (iteri, init_val), ()). Here iteri is a constant.

Do you want to debug the code/HLO when iteri reaches a specific value? For example, given:

    def cond_fn(iteri, x):
      return iteri > 0

    def body_fn(iteri, x):
      return iteri - 1, torch.add(x, 1)

    init_val = torch.tensor(3, dtype=torch.int32, device=device)
    iteri = torch.tensor(10, device=device)
    _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))

you might want to debug when iteri == 5?

If that is the case, I would suggest using _xla_while_loop_wrapper directly, like the test case added in https://github.com/pytorch/xla/pull/7993, and you could do some debugging in the code. As for the HLO, the HLO for each iteri/trip should be the same once the functions are wrapped into while_loop, so the HLO we got before would be the HLO used for every iteri/trip.

_xla_while_loop_wrapper skips some constraints, and we could catch a specific iteri value during the loop iterations.

If you want the loop to run only once with iteri == 5, we could limit iteri and init_val to 5 like this:

    def cond_fn(iteri, x):
      # if iteri==5:
      #   print("cond_fn: iteri is 5 now !!!")
      # print("iteri: ", iteri)
      return iteri > 5

    def body_fn(iteri, x):
      # if iteri==5:
      #   print("body_fn: iteri is 5 now !!!")
      # print("iteri: ", iteri)
      return iteri - 1, torch.add(x, 1)

    init_val = torch.tensor(5, dtype=torch.int32, device=device)
    iteri = torch.tensor(5, device=device)
    _, res_with_loop = _xla_while_loop_wrapper(cond_fn, body_fn, (iteri, init_val), additional_inputs=())

The test case above will run only once, with iteri == 5.

ManfeiBai commented 1 week ago

Questions:

  1. Should the condition input be a constant?

For the while operator test (https://github.com/pytorch/xla/blob/master/test/test_while_loop.py#L28):

def test_while_loop_addition(self):
    device = xm.xla_device()

    def cond_fn(iteri, x):
      return iteri > 0

    def body_fn(iteri, x):
      return iteri - 1, torch.add(x, 1)

    init_val = torch.tensor(3, dtype=torch.int32, device=device)
    iteri = torch.tensor(10, device=device)
    _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
    _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))
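
(For reference, _fake_while_loop in that test is essentially a plain Python loop used as the eager baseline the lowered while_loop result is compared against; a rough sketch, assuming the helper keeps its current form:)

    def _fake_while_loop(cond_fn, body_fn, operands):
      # Run the loop eagerly, trip by trip, without lowering to XLA::While.
      while cond_fn(*operands):
        operands = body_fn(*operands)
      return operands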

The cond input has iteri and x: iteri is the iterator used as the trip count; it gets initialized and then changed on each trip. x is the carry value that is updated on each trip.

Since iteri and x need to change on each trip, do we have more context on the requirement for the condition input to be a constant, or any use cases?

Or would you like to try _xla_while_loop_wrapper as in the comment above?

  2. I see some constraints on the while op. Is there a plan to make it work with different kinds of input values?

Yes, there are some constraints on while_loop (XLA::While); these constraints exist due to limitations of the XLA::While op when running on TPU.

As for making it work with different kinds of input values, do we have a use case or test case as an example?

aws-rhsoln commented 1 week ago

So for the while-loop unroller pass in OpenXLA, the condition input to the while loop needs to be a constant so that the unroller pass can determine the trip count.
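
To illustrate what I mean by unrolling (a plain-Python sketch of the idea, not the actual OpenXLA HLO pass): once the trip count is known at compile time, the While construct can be replaced by straight-line repetitions of the body.

    def unroll(body_fn, carry, n):
      # Hypothetical illustration: with a known trip count n, the loop
      # becomes n copies of the body, enabling per-iteration optimizations
      # that a dynamic While cannot get.
      for _ in range(n):
        carry = body_fn(*carry)
      return carry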

Can you expand a bit more on the MNIST model example? Are we planning to do gradient accumulation using this while loop?