Closed: anijain2305 closed this 2 years ago
@voznesenskym Assigning to you arbitrarily. Let me know if that's OK.
@yanboliang Assigning to you. At first glance, this seems to be a decomp issue. Can you please take a look?
I have identified the root cause: this is actually a PyTorch core bug.
The failure is caused by this line, because `torch.index_put_` does not support `slice(None)`s as indices. I also verified this by running the PT unit tests: all `index_add_` op tests use `dim = 0`, where this problem can't be exposed.
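To illustrate why `dim = 0` tests cannot expose this, here is a minimal pure-Python sketch. It assumes the decomp builds its index tuple by padding the leading dimensions with `slice(None)`; the helper name `build_index_tuple` is hypothetical and is not the actual decomp code.

```python
def build_index_tuple(dim, index, filler=slice(None)):
    """Hypothetical helper: put `index` at position `dim`,
    padding every leading dimension with `filler`."""
    return (filler,) * dim + (index,)


index = [0, 2]  # stand-in for an index tensor

# dim = 0: there are no leading dimensions, so no slice(None) is produced.
indices_dim0 = build_index_tuple(0, index)

# dim = 1: one slice(None) is produced and would be passed on as an index.
indices_dim1 = build_index_tuple(1, index)

print(indices_dim0)
print(indices_dim1)
```

Under that assumption, any test that only uses `dim = 0` never passes a `slice(None)` on as an index, so the unsupported case is never reached.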
Inductor removed the lowering for `index_select` and added a decomp for `index_{add,add_}` at https://github.com/pytorch/torchdynamo/pull/1292. I tried reverting that change and found that it fixes this bug.
I think if we can easily support `slice(None)`s, we should fix this inside of PyTorch; otherwise, should we remove the decomps for these ops?
Thanks @SherlockNoMad for helping me navigate and find the root cause. cc @jansel @ngimel @lezcano Any suggestions?
I believe changing that line from `slice(None)` to `None` should fix the issue. Could you confirm?
I put up a fix: https://github.com/pytorch/pytorch/pull/86266
Thanks @lezcano for the fix, let's not remove the decomps and enable tests when pytorch pin is updated. @yanboliang can you please add tests to inductor that would expose this?
@lezcano Thanks for your PR, I verified it fixed my issue! @ngimel I'll add tests to inductor when pytorch pin is updated.
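For reference, a test that exposes this needs `dim != 0`. Below is a rough pure-Python sketch of the expected `index_add` result along `dim = 1` for a 2-D input, which such a test could compare against. The function `index_add_2d` is a hypothetical reference written for illustration, not PyTorch or inductor code.

```python
def index_add_2d(x, dim, index, src):
    """Hypothetical reference for 2-D index_add on nested lists.
    dim = 0: out[index[i]][c] += src[i][c]
    dim = 1: out[r][index[i]] += src[r][i]
    """
    out = [row[:] for row in x]  # copy so the input is left untouched
    for i, target in enumerate(index):
        if dim == 0:
            for c in range(len(out[0])):
                out[target][c] += src[i][c]
        else:
            for r in range(len(out)):
                out[r][target] += src[r][i]
    return out


x = [[0, 0, 0], [0, 0, 0]]
src = [[1, 2], [3, 4]]
index = [2, 0]

# Column 2 receives src[:, 0] and column 0 receives src[:, 1].
result = index_add_2d(x, 1, index, src)
print(result)  # [[2, 0, 1], [4, 0, 3]]
```

Repeated entries in `index` accumulate into the same column, so a test case with duplicates also checks the accumulation behavior.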
benchmarks/huggingface.py --training -dcuda --accuracy --training --inductor --only=XLNetLMHeadModel
Error
Repro