pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[dynamo/guards] Symbolic shape guard does not fail, thus failing to recompile when shape changes #113875

Closed jon-chuang closed 9 months ago

jon-chuang commented 9 months ago

šŸ› Describe the bug

The dynamic shape guard does not fail in the following repro:

import torch
def fn(x):
    if x.size() != (5, 1, 2, 3):
        return x.cos()
    return x.sin()

opt_fn = torch.compile(fn, backend="eager", dynamic=True)

x = torch.ones(5, 1, 3, 4)
x2 = torch.ones(5, 1, 2, 3)
torch.testing.assert_close(fn(x), opt_fn(x))
# Installs guard: ~(Eq(L['x'].size()[0], 5) & Eq(L['x'].size()[2], 2) & Eq(L['x'].size()[3], 3))
torch.testing.assert_close(fn(x2), opt_fn(x2))
# ~(Eq(L['x'].size()[0], 5) & Eq(L['x'].size()[2], 2) & Eq(L['x'].size()[3], 3)) should fail, but doesn't!

Flipping the order of x and x2 passes, with a recompile:

torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function fn in /home/jonch/Desktop/Programming/mlsys/pytorch/test/dynamo/test_repros.py:3578
[2023-11-16 11:47:42,965] torch._dynamo.guards.__recompiles: [DEBUG]     triggered by the following guard failure(s):
[2023-11-16 11:47:42,965] torch._dynamo.guards.__recompiles: [DEBUG]     - Eq(L['x'].size()[0], 5) & Eq(L['x'].size()[2], 2) & Eq(L['x'].size()[3], 3)

Versions

main

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519

jon-chuang commented 9 months ago

@ezyang not sure why this was closed? I think it is still a bug

The PR linked is where the bug was identified.

ezyang commented 9 months ago

Oh, I thought that PR was fixing this bug. Have you root caused this one then?

jon-chuang commented 9 months ago

Nope, not yet.

ezyang commented 9 months ago

It's because we messed up codegen for negation:

 ~(Eq(L['x'].size()[0], 5) & Eq(L['x'].size()[2], 2) & Eq(L['x'].size()[3], 3))

It's wrong to use bitwise negation here. Applied to a Python bool, ~ yields a nonzero integer (~False == -1, ~True == -2), which is always truthy, so the negated guard can never evaluate falsy and will do the wrong thing.
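A minimal sketch of the failure mode in plain Python, no PyTorch needed. It assumes the guard is ultimately evaluated as an ordinary Python boolean expression; the variable names below are illustrative, not Dynamo's actual codegen.

```python
# Shape of the first input, x = torch.ones(5, 1, 3, 4).
size = (5, 1, 3, 4)

# The conjunction from the installed guard (reconstructed by hand):
# Eq(size[0], 5) & Eq(size[2], 2) & Eq(size[3], 3)
conjunction = size[0] == 5 and size[2] == 2 and size[3] == 3  # False for x

buggy_guard = ~conjunction       # bitwise: ~False == -1, a truthy int
correct_guard = not conjunction  # logical: True

print(buggy_guard, bool(buggy_guard))  # -1 True
print(correct_guard)                   # True

# Second input, x2 = torch.ones(5, 1, 2, 3): the conjunction now holds,
# so the negated guard should fail (evaluate falsy) and force a recompile.
size2 = (5, 1, 2, 3)
conjunction2 = size2[0] == 5 and size2[2] == 2 and size2[3] == 3  # True

print(bool(~conjunction2))  # True  -- buggy guard still "passes", no recompile
print(not conjunction2)     # False -- correct guard fails as expected
```

Since ~b is truthy for both b == True and b == False, the buggy guard passes unconditionally, which matches the observed behavior: no recompile on x2.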