Open aws-rhsoln opened 1 week ago
@ManfeiBai @tengyifei
> This way the XLA compiler won't know the trip count unless it evaluates the parameter during compilation.
Correct. I'm wondering whether you have a specific motivation to teach XLA the trip count? Even if one input to the While op is a constant, the cond computation still has to compare it with another constant to determine whether to break the loop.
> I have tried other ways to make the condition input a constant, however, that input gets optimized away
Could you share how you produced the HLO in the second section? Thanks.
I want the ability to unroll the while loop, which requires XLA to know the trip count.
I generated the constant by using `xla_builder` and passed that as input to `_xla_while_loop_wrapper(cond_fn, body_fn, (iteri, init_val), ())`. Here `iteri` is a constant.
Do you want to debug code/HLO when `iteri` gets to a specific value? Such as for:
```python
def cond_fn(iteri, x):
  return iteri > 0

def body_fn(iteri, x):
  return iteri - 1, torch.add(x, 1)

init_val = torch.tensor(3, dtype=torch.int32, device=device)
iteri = torch.tensor(10, device=device)
_, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
```
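For reference, the loop above can be traced in plain Python to see the trip count the compiler would need. This is a sketch using ints in place of XLA tensors, so `torch` and `while_loop` are not required:

```python
def cond_fn(iteri, x):
    return iteri > 0

def body_fn(iteri, x):
    # decrement the counter, add 1 to the carry on each trip
    return iteri - 1, x + 1

iteri, x = 10, 3
trips = 0
while cond_fn(iteri, x):
    iteri, x = body_fn(iteri, x)
    trips += 1

print(trips, iteri, x)  # the loop runs 10 times: trips=10, iteri=0, x=13
```

The trip count (10) is only knowable here because `iteri` starts as a concrete value; when it is a runtime parameter, the compiler cannot derive it statically.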
You might want to debug when `iteri == 5`?
If that is the case, I would suggest using `_xla_while_loop_wrapper` directly, like the test case added in https://github.com/pytorch/xla/pull/7993, and you could do some debugging in the code. For the HLO, I would say the HLO for each iteri/trip should be the same once they are wrapped into while_loop, so the HLO we get beforehand would be the HLO used by each iteri/trip. `_xla_while_loop_wrapper` will skip some constraints, and we could catch a specific `iteri` value during the loops.
If you want the loop to run only once with `iteri == 5`, we might want to limit `iteri` and `init_val` to 5 like:
```python
def cond_fn(iteri, x):
  # if iteri == 5:
  #   print("cond_fn: iteri is 5 now !!!")
  # print("iteri: ", iteri)
  return iteri > 5

def body_fn(iteri, x):
  # if iteri == 5:
  #   print("body_fn: iteri is 5 now !!!")
  # print("iteri: ", iteri)
  return iteri - 1, torch.add(x, 1)

init_val = torch.tensor(5, dtype=torch.int32, device=device)
iteri = torch.tensor(5, device=device)
_, res_with_loop = _xla_while_loop_wrapper(cond_fn, body_fn, (iteri, init_val), additional_inputs=())
```
The test case above will run only once, with `iteri == 5`.
Question:
- Should the condition input be a constant?
For the while operator test (https://github.com/pytorch/xla/blob/master/test/test_while_loop.py#L28):
```python
def test_while_loop_addition(self):
  device = xm.xla_device()

  def cond_fn(iteri, x):
    return iteri > 0

  def body_fn(iteri, x):
    return iteri - 1, torch.add(x, 1)

  init_val = torch.tensor(3, dtype=torch.int32, device=device)
  iteri = torch.tensor(10, device=device)
  _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
  _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
  self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))
```
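The `_fake_while_loop` helper in the test is an eager reference implementation that the real `while_loop` result is compared against. A minimal sketch of what it does, using plain ints in place of XLA tensors (treat this as an illustration, not the exact source in test_while_loop.py):

```python
def _fake_while_loop(cond_fn, body_fn, operands):
    # run the loop eagerly in Python instead of lowering to XLA::While
    while cond_fn(*operands):
        operands = body_fn(*operands)
    return operands

# example with plain ints standing in for XLA tensors
res = _fake_while_loop(lambda i, x: i > 0,
                       lambda i, x: (i - 1, x + 1),
                       (10, 3))
print(res)  # (0, 13)
```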
The cond input has `iteri` and `x`: `iteri` is the iterator used as the trip count; it gets initialized and changed on each trip. `x` is the carry value computed on each trip. Since `iteri` and `x` need to change on each trip, do we have more context on the requirement that the condition input be a constant, or any use cases?
Or do you want to try `_xla_while_loop_wrapper` as in the comment above?
- I see some constraints for the while op; is the plan to make it workable for different values of inputs?
Yes, there are some constraints for while_loop (XLA::While); these constraints come from the XLA::While op's limitations for running on TPU. For "make it workable for different values of inputs", do we have a use/test case as an example?
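One such constraint (the size/shape restriction discussed below) can be made concrete with a small check. This is a hypothetical helper written for illustration, not torch_xla code; for real tensors the comparison would be on `.shape` and `.dtype`:

```python
def check_while_signature(init_operands, body_outputs):
    """XLA::While requires the carried tuple to keep the same arity and
    per-element shape across iterations: cond_fn's inputs, body_fn's
    inputs and outputs, and while_loop's initial operands must all match."""
    if len(init_operands) != len(body_outputs):
        raise ValueError("body must return as many values as it receives")
    for i, (a, b) in enumerate(zip(init_operands, body_outputs)):
        # stand-in for a tensor shape/dtype comparison on plain values
        if type(a) is not type(b):
            raise ValueError(f"carried value {i} changed type: "
                             f"{type(a).__name__} -> {type(b).__name__}")

check_while_signature((10, 3), (9, 4))       # OK: same arity, same types
# check_while_signature((10, 3), (9.0, 4))   # would raise: element 0 changed type
```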
If "different values of inputs" means different values of `iteri`, `x`: we would add more use cases, such as an MNIST model and more complex models in the future; this is still WIP. If "different values of inputs" means using different `iteri`, `x` in cond_fn's args, body_fn's args and returns, and while_loop's inputs, this would break the restriction mentioned above that cond_fn's inputs, body_fn's inputs and returns, and while_loop's inputs must be the same size and shape; that would break XLA::While's prerequisite, and that would be an XLA op question.

So for the while loop unroller pass in OpenXLA, the condition input to the while loop needs to be a constant for the unroller pass to determine the trip count.
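To make the unrolling requirement concrete, here is a conceptual Python sketch (not the OpenXLA pass itself): if the compiler can prove the trip count is a constant `N`, it can replace the While op with `N` back-to-back copies of the body, eliminating the control flow entirely.

```python
TRIP_COUNT = 10  # must be a compile-time constant for the unroller

def body(x):
    return x + 1

def looped(x):
    # what XLA::While computes when the trip count is TRIP_COUNT
    for _ in range(TRIP_COUNT):
        x = body(x)
    return x

def unrolled(x):
    # the unroller emits the body TRIP_COUNT times with no loop
    x = body(x); x = body(x); x = body(x); x = body(x); x = body(x)
    x = body(x); x = body(x); x = body(x); x = body(x); x = body(x)
    return x

print(looped(3), unrolled(3))  # both print 13
```

If the trip count is a runtime parameter rather than a constant, this transformation is impossible, which is why the condition input matters.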
Can you expand a bit more on the MNIST model example? Are we planning to do grad accumulation using this while loop?
When running the while operator test (https://github.com/pytorch/xla/blob/master/test/test_while_loop.py#L28), I see an HLO that looks like the following (this is the un-optimized graph):

In this case the input tuple has a parameter. This way the XLA compiler won't know the trip count unless it evaluates the parameter during compilation. Shouldn't this be a constant?

I have tried other ways to make the condition input a constant; however, that input gets optimized away and I end up getting an HLO as follows:
cond_hlo:
Question: