pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration

Deduplicate the operands passed into torch.cond after dynamo tracing. #103530

Open tugsbayasgalan opened 1 year ago

tugsbayasgalan commented 1 year ago

🚀 The feature, motivation and pitch

Currently, Dynamo lifts the free variables inside `torch.cond` branches as extra inputs to the branch graphs. As a result, for simplicity, we naively extend the `torch.cond` operands list with the lifted free variables from each branch. For example, consider `cond(pred, true_fn, false_fn, [x])` where `true_fn` has free variables `a, b, c` and `false_fn` has free variables `a, b, d`. Dynamo then rewrites the call as `cond(pred, true_fn, false_fn, [x, a, b, c, a, b, d])`, duplicating `a` and `b`. Ideally, we should detect this and deduplicate the operands list. A minimal sketch of one possible deduplication pass follows below.
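For illustration only, here is a sketch of what such a pass could look like, not the actual Dynamo implementation: keep the first occurrence of each lifted tensor (by object identity) and record an index remapping that the branch placeholders could be rewritten against. The name `dedupe_operands` and the remap representation are hypothetical.

```python
def dedupe_operands(operands):
    # operands: the list Dynamo built, e.g. [x, a, b, c, a, b, d],
    # where duplicates are the same Python object (identity, not value).
    seen = {}      # id(tensor) -> position in the deduped list
    deduped = []   # operands with duplicates dropped
    remap = []     # old index -> new index, for rewriting branch inputs
    for t in operands:
        if id(t) not in seen:
            seen[id(t)] = len(deduped)
            deduped.append(t)
        remap.append(seen[id(t)])
    return deduped, remap

# For the example above:
#   deduped == [x, a, b, c, d]
#   remap   == [0, 1, 2, 3, 1, 2, 4]
```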

Alternatives

No response

Additional context

No response

tugsbayasgalan commented 1 year ago

cc: @zou3519

vadimkantorov commented 1 year ago

torch.cond does not seem to be documented currently: https://pytorch.org/docs/master/search.html?q=torch.cond&check_keywords=yes&area=default# @albanD should I file a separate issue for this?

I'm looking to implement some adaptive computation (a ternary op / if-thens based on intermediate computation results) to save time by skipping useless work. A sketch of the intended usage follows below.
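For context, here is a minimal sketch of the kind of data-dependent branching `torch.cond` is meant to enable, assuming a recent PyTorch build where `torch.cond` is exposed (in older releases it lived under `functorch.experimental.control_flow.cond`); the branch functions here are made up for illustration:

```python
import torch

def expensive_fn(x):
    return x @ x.T  # heavier path

def cheap_fn(x):
    return x * 2    # cheaper path

@torch.compile
def f(x):
    # Branch on a value computed at runtime; both branches are traced
    # into the graph, and the predicate picks one when the model runs.
    # Both branches must return tensors with matching shape and dtype.
    return torch.cond(x.sum() > 0, expensive_fn, cheap_fn, (x,))

print(f(torch.randn(3, 3)))
```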

albanD commented 1 year ago

We can file an issue. But in general, I don't expect that any OSS user needs it atm. So that's definitely not the worst offender of missing docs :p