XuehaiPan opened 3 weeks ago
Affected PRs:
I found that we try to inline functions without checking whether they return infinite generators:

`BaseUserFunctionVariable.call_function`
-> `InstructionTranslatorBase.inline_user_function_return`
-> `InliningInstructionTranslator.inline_call`
-> `InliningInstructionTranslator.inline_call_`
-> `tracer.run(); ListIteratorVariable(tracer.generated_items, mutable_local=MutableLocal())`

Here we collect every item from the iterator into `tracer.generated_items` and then build a `ListIteratorVariable` from them.
In `tracer.run()`, we loop over `InliningGeneratorInstructionTranslator.YIELD_FROM` and push each item into `generated_items`. For an infinite generator, this results in an endless loop.
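To make the failure mode concrete outside of Dynamo, the sketch below mimics eager realization in plain Python: it drains a generator into a list before returning anything. `realize_eagerly` and its `max_items` guard are hypothetical names used only for illustration; without the guard, the loop would never end for `infinite()`.

```python
import itertools
from typing import Iterator, List, TypeVar

T = TypeVar("T")


def realize_eagerly(gen: Iterator[T], max_items: int = 1000) -> List[T]:
    """Collect every item a generator yields before returning (hypothetical helper).

    The max_items guard exists only so this example terminates; an unguarded
    version of this loop spins forever on an infinite generator.
    """
    items: List[T] = []
    for item in gen:
        items.append(item)
        if len(items) > max_items:
            raise RuntimeError("generator did not terminate; cannot realize eagerly")
    return items


def finite():
    yield from range(3)


def infinite():
    yield from itertools.count()


print(realize_eagerly(finite()))  # [0, 1, 2]
try:
    realize_eagerly(infinite())
except RuntimeError as exc:
    print(exc)  # generator did not terminate; cannot realize eagerly
```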
Infinite iterators, such as the following, are not inline-able:

- `itertools.count()`
- `itertools.repeat(obj)` (with `times` omitted)

Yet iterating over them does not cause an infinite loop in eager mode, as long as the loop itself terminates:

    import itertools

    for n in itertools.count():
        ...  # do something
        if condition:
            break

    for i, j in zip(range(256), itertools.repeat(obj)):
        ...  # do something
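For a concrete, runnable version of the two eager-mode loops, the placeholders are replaced with illustrative choices: `condition` becomes `n * n > 50` and `obj` becomes the string `"obj"`. Both loops finish because items are pulled one at a time rather than collected up front.

```python
import itertools

# Loop 1: an unbounded counter, terminated by an explicit break
first_square_over_50 = None
for n in itertools.count():
    if n * n > 50:
        first_square_over_50 = n
        break

# Loop 2: zip stops as soon as the finite range is exhausted
pairs = [(i, j) for i, j in zip(range(256), itertools.repeat("obj"))]

print(first_square_over_50)  # 8
print(len(pairs))  # 256
```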
We need a way to delay the inlining process:

- `list(itertools.count())` is not inline-able: it has infinitely many elements.
- `list(zip(range(10), itertools.count()))` is inline-able: it has 10 constant elements.

If we really care about this, the right thing is to just properly support generators.
An iterator is a special case of a generator that is always sent `None`: `next(gen)` is equivalent to `gen.send(None)`.
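A short plain-Python check of that equivalence: stepping a generator with `next()` has the same effect as sending `None`, so every value received at a `yield` is `None`. The generator `numbers` is an illustrative example, not Dynamo code.

```python
def numbers():
    """Yield 0, 1, 2 and record whatever value each `yield` receives."""
    received = []
    for i in range(3):
        value = yield i
        received.append(value)
    return received


gen = numbers()
print(next(gen))        # 0
print(gen.send(None))   # 1, same effect as next(gen)
print(next(gen))        # 2
try:
    next(gen)
except StopIteration as stop:
    print(stop.value)   # [None, None, None]
```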
I think the most viable solution is to support callable iterators:

`it = iter(callable, sentinel)`
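For reference, the two-argument form of the built-in `iter()` calls `callable` with no arguments on every step and stops as soon as the result equals `sentinel`. The plain-Python sketch below shows a terminating case and an infinite one; the specific callables are illustrative choices. Because `int()` always returns `0`, `iter(int, 1)` never stops and can only be consumed lazily.

```python
import itertools

# A callable that returns 0, 1, 2, ... on successive calls
counter = itertools.count()
take_next = counter.__next__

# Finite: iteration stops when the callable returns the sentinel (5)
finite = list(iter(take_next, 5))
print(finite)  # [0, 1, 2, 3, 4]

# Infinite: int() always returns 0, which never equals the sentinel 1
forever_zero = iter(int, 1)
print([next(forever_zero) for _ in range(3)])  # [0, 0, 0]
```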
See also:

Dynamo's `zip` does not support infinite iterators (e.g., `itertools.count()`). Dynamo always realizes an iterable into list items, which leads to an infinite loop for infinite iterators. Also, fetching items from an iterator may have side effects, so we should not realize the iterator into a sequence all at once.
_Originally posted by @XuehaiPan in https://github.com/pytorch/pytorch/pull/133876#discussion_r1722156380_
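In eager Python, the built-in `zip` already behaves lazily: it pulls one item from each input per step and stops at the first exhausted input without advancing the later ones. The sketch below makes the fetches visible through a hypothetical `logged` wrapper whose log entries stand in for side effects. Note the interleaved fetch order, and that the infinite `itertools.count()` is advanced exactly 10 times, matching the 10-element `zip` case discussed earlier.

```python
import itertools
from typing import Iterable, Iterator, List


def logged(name: str, iterable: Iterable[int], log: List[str]) -> Iterator[int]:
    """Yield items from `iterable`, recording each fetch to stand in for a side effect."""
    for item in iterable:
        log.append(f"{name}:{item}")
        yield item


log: List[str] = []
pairs = zip(logged("a", range(10), log), logged("b", itertools.count(), log))

first = next(pairs)
print(first)  # (0, 0)
print(log)    # ['a:0', 'b:0'], one fetch per input so far

rest = list(pairs)
print(len(rest) + 1)  # 10 pairs in total
print(len(log))       # 20: "b" is never advanced past its 10th item
```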
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames