Open · ezyang opened 5 months ago

🐛 Describe the bug
`_tensors_definitely_do_not_overlap` generates a quadratic number of guards when the tensors in question have symbolic shapes.
@bdhirsh has a mini repro:
```python
import torch

@torch.compile(dynamic=True)
def f(*args):
    for a in args:
        a.add_(1)
    return args[0]

x = torch.ones(1000)
args = x.split(10)
out = f(*args)
```
Running with `TORCH_LOGS="+dynamic"` this spews thousands of lines of guards. Here `x.split(10)` produces 100 views of the same storage, all of which are mutated, so the pairwise overlap analysis performs 100 · 99 / 2 = 4950 checks, each of which can emit several guards on the symbolic sizes, strides, and storage offsets.
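For intuition, here is a minimal sketch (not the actual PyTorch source) of why the guard count is quadratic. It assumes the internal helper `_tensors_definitely_do_not_overlap` is in scope; its module path has moved between releases, so the import is omitted.

```python
# Hedged sketch, not the real implementation: the overlap analysis
# conceptually compares every pair of aliased, mutated inputs, so with
# N views of one storage it performs O(N^2) comparisons.
def find_overlapping_pairs(tensors):
    overlapping = []
    for i in range(len(tensors)):
        for j in range(i):
            # When sizes/strides/storage offsets are symbolic, each call
            # can record several guards, so the total number of guards
            # grows quadratically in len(tensors).
            if not _tensors_definitely_do_not_overlap(tensors[i], tensors[j]):
                overlapping.append((i, j))
    return overlapping
```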
The general proposed fix strategy from @voznesenskym was to make overlap specialization a first-class citizen for tensor guards.
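One way to read that proposal (a hypothetical illustration, not an existing PyTorch API; the class name and interface below are assumptions): rather than encoding no-overlap as thousands of individual symbolic-expression guards at compile time, record a single guard object over the group of aliased inputs and re-evaluate the overlap check directly on concrete tensors at guard-evaluation time.

```python
# Hypothetical sketch of a "first-class" overlap guard; the name and
# API shape are assumptions for illustration only.
class NoOverlapGuard:
    def __init__(self, input_indices):
        # Indices of the compiled function's inputs that share a storage.
        self.input_indices = input_indices

    def check(self, args):
        # Re-run the pairwise check on concrete tensors when the guard
        # is evaluated; this is still O(N^2) comparisons, but they are
        # cheap integer checks and emit no symbolic guards.
        tensors = [args[i] for i in self.input_indices]
        return all(
            _tensors_definitely_do_not_overlap(tensors[i], tensors[j])
            for i in range(len(tensors))
            for j in range(i)
        )
```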
Roll up to https://github.com/pytorch/pytorch/issues/118213
This is the proximal cause of the OOM / compile-time regressions in https://www.internalfb.com/intern/sevmanager/view/s/382123/ and https://www.internalfb.com/intern/sevmanager/view/s/382616/.
Versions

main
cc @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang
@bdhirsh I am removing myself from this. IIUC, from our 1:1 discussion you mentioned that moving the guards to C++ won't really help here. But let me know if you need any help.