
IRSimplifier performs unsafe transformation on mod: m*n%m -> 0 #53442

Open huiguoo opened 3 years ago

huiguoo commented 3 years ago

🐛 Bug

The integer transformation in IRSimplifier, m*n % m -> 0, is not safe because m is not guaranteed to be non-zero. We need to forbid this optimization.
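
A minimal eager-mode repro of the hazard (this mirrors ZolotukhinM's transcript further down the thread): when m is zero, eager pytorch raises rather than returning 0, so folding the expression to 0 silently drops the error.

>>> import torch
>>> m = torch.zeros((1, 1), dtype=torch.int32)
>>> n = torch.ones((1, 1), dtype=torch.int32) * 7
>>> n * m % m
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: ZeroDivisionError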

Chillee commented 3 years ago

Isn't n % 0 undefined? When is it possible to have this kind of expression?

ZolotukhinM commented 3 years ago

> Isn't n % 0 undefined? When is it possible to have this kind of expression?

In C it is undefined; in pytorch it is probably defined (i.e., it probably either means NaN or throws an exception). A user could construct such an expression manually, e.g.

// Manually constructs A[i] = i*x % x; nothing here guarantees x != 0.
VarHandle x("x", kInt);
Tensor* A = Compute("A", {{100, "i"}}, [&](const VarHandle& i) { return i * x % x; });

Another thing worth mentioning is that we're not simplifying m*n/n to m, so at the very least we should be consistent in what we do there.

bertmaher commented 3 years ago

It might be worth noting the super-annoying fact that this transformation in general doesn't match pytorch semantics:

>>> x = torch.ones((1,1), dtype=torch.int32) * 0x12345
>>> y = torch.ones((1,1), dtype=torch.int32) * 0x54321
>>> x * y % y
tensor([[233349]], dtype=torch.int32)
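
For reference, the surprising 233349 is just two's-complement wrap-around followed by a positive remainder. A quick sanity check with plain Python ints (a sketch, assuming int32 wraps the way eager's do):

>>> m, n = 0x12345, 0x54321
>>> prod = m * n                              # exact product: 25714858725
>>> wrapped = (prod + 2**31) % 2**32 - 2**31  # simulate int32 wrap-around
>>> wrapped
-54945051
>>> wrapped % n  # Python %, like torch's, takes the divisor's sign
233349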

bertmaher commented 3 years ago

Although I kind of think that might be undefined behavior in C++ (signed integer overflow). But what I find surprising is that c++ compilers seem to not do the m * n % n -> 0 optimization either, which makes me a bit worried that it's actually invalid, not just UB: https://gcc.godbolt.org/z/nYKrr5

bertmaher commented 3 years ago

I think this optimization is actually legal in general. I consulted with our LLVM team, and this type of optimization is done for division (see https://github.com/llvm/llvm-project/blob/main/llvm/lib/Analysis/InstructionSimplify.cpp#L1054), and it seems like an oversight to not do it for modulus.

It's possible that we should only do it when the range of m * n is known not to overflow -- although we probably want to special-case or assert this for indexing math, since a tensor with numel > 2^64 is probably not a supported thing anyway, even if you can technically declare one (I hope).
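
To make the range idea concrete, here is a rough sketch of the guard in Python (the bounds machinery and names are hypothetical, not the actual NNC analysis):

INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

def can_fold_mul_mod(m_bounds, n_bounds):
    """Whether (m*n) % m can be folded to 0, given inclusive
    (lo, hi) bounds for m and n (hypothetical range info)."""
    m_lo, m_hi = m_bounds
    n_lo, n_hi = n_bounds
    if m_lo <= 0 <= m_hi:
        return False  # m may be zero: folding would drop ZeroDivisionError
    # A product over an interval box attains its extremes at the corners.
    corners = [a * b for a in (m_lo, m_hi) for b in (n_lo, n_hi)]
    return INT32_MIN <= min(corners) and max(corners) <= INT32_MAX

assert can_fold_mul_mod((1, 16), (0, 99))      # small index math: safe
assert not can_fold_mul_mod((0, 16), (0, 99))  # m may be zero: unsafe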

ZolotukhinM commented 3 years ago

I'm not sure we could use LLVM (and C/C++ in general) as guidance here, since any discrepancy from pytorch eager would probably be considered a bug.

Here is the behavior I think we need to replicate:

>>> x = torch.ones((1,1), dtype=torch.int32) * 7
>>> y = torch.zeros((1,1), dtype=torch.int32)
>>> x * y / y
tensor([[nan]])
>>> x * y % y
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: ZeroDivisionError
>>> x * y // y
/home/mvz/pytorch/torch/tensor.py:565: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:341.)
  return torch.floor_divide(self, other)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/mvz/pytorch/torch/tensor.py", line 29, in wrapped
    return f(*args, **kwargs)
  File "/home/mvz/pytorch/torch/tensor.py", line 565, in __floordiv__
    return torch.floor_divide(self, other)
RuntimeError: ZeroDivisionError

bertmaher commented 3 years ago

Heh, so I banned aten::fmod on integers from fusion groups for exactly this reason. It's really troublesome to preserve exceptions from edge cases like this in generated code.

  1. I don't know if this is actually feasible, but maybe there should be a separate set of rules for user vs. internal computations. Like, if we generate m * n % n as an indexing expression, it's fine for us to simplify it to 0, but if the user has typed that in python, we should honor it (see the sketch after this list).
  2. Not exactly this question, but I feel uneasy about trying to replicate any eager semantics that are the result of UB. Presumably most things just work out (our int32s will overflow like eager's) but it feels like a potentially nasty path.
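
A rough sketch of what (1) could look like, with a hypothetical user_written provenance flag on Mod nodes (purely illustrative; NNC's IR has no such field):

from dataclasses import dataclass
from typing import Union

@dataclass
class Var:
    name: str

@dataclass
class Mul:
    lhs: "Expr"
    rhs: "Expr"

@dataclass
class Mod:
    lhs: "Expr"
    rhs: "Expr"
    user_written: bool = False  # hypothetical provenance bit

Expr = Union[Var, Mul, Mod]

def simplify_mod(e: Expr) -> Expr:
    # Fold m*n % n -> 0 only for compiler-generated indexing math;
    # user-written mods keep eager semantics (errors, wrap-around).
    if (isinstance(e, Mod) and not e.user_written
            and isinstance(e.lhs, Mul) and isinstance(e.lhs.rhs, Var)
            and isinstance(e.rhs, Var) and e.lhs.rhs.name == e.rhs.name):
        return Var("0")
    return e

Under a scheme like this, a user-written i*x % x would survive to codegen unchanged, while the compiler's own flattened-index math would still simplify.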