huiguoo opened this issue 3 years ago
Isn't n % 0 undefined? When is it possible to have this kind of expression?
In C it is undefined; in pytorch it probably is defined (i.e. it probably either means NaN or throws an exception). A user could construct such an expression manually, e.g.
VarHandle x("x", kInt);
Tensor *A = Compute("A", {{100, "i"}}, [&](const VarHandle& i) { return i*x%x; });
Another thing worth mentioning is that we're not simplifying m*n/n to m, so at least we'd better be consistent in what we do there.
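For what it's worth, here is a quick eager-mode sketch (reusing the int32 constants from the overflow example below) suggesting that folding m*n/n to m would be observably wrong too, both under overflow and when n is zero:

import torch

m = torch.full((1, 1), 0x12345, dtype=torch.int32)
n = torch.full((1, 1), 0x54321, dtype=torch.int32)

# m * n wraps around in int32, so truncating integer division by n does not recover m
print(torch.div(m * n, n, rounding_mode='trunc'))

# and when the divisor is zero, true division yields nan rather than m
z = torch.zeros((1, 1), dtype=torch.int32)
print(m * z / z)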
It might be worth noting the super-annoying fact that this transformation in general doesn't match pytorch semantics:
>>> x = torch.ones((1,1), dtype=torch.int32) * 0x12345
>>> y = torch.ones((1,1), dtype=torch.int32) * 0x54321
>>> x * y % y
tensor([[233349]], dtype=torch.int32)
Although I kind of think that might be undefined behavior in C++ (signed integer overflow). But what I find surprising is that c++ compilers seem to not do the m * n % n -> 0 optimization either, which makes me a bit worried that it's actually invalid, not just UB: https://gcc.godbolt.org/z/nYKrr5
I think this optimization is actually legal in general. I consulted with our LLVM team, and this type of optimization is done for division (see https://github.com/llvm/llvm-project/blob/main/llvm/lib/Analysis/InstructionSimplify.cpp#L1054), and it seems like an oversight to not do it for modulus.
It's possible that we should only do it when the range of m * n is known not to overflow -- although we probably want to special case/assert this for indexing math since a tensor with numel>2^64 is probably not a supported thing anyways, even if you can technically declare it (I hope).
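To sketch that idea concretely (a purely hypothetical helper, not existing IRSimplifier code), a range-based guard would only allow the fold when the divisor's range excludes zero and the product cannot overflow the index type:

# Hypothetical sketch: decide from known value ranges whether folding
# m*n % n -> 0 is safe (divisor non-zero, product fits in int32).
INT32_MIN, INT32_MAX = -2**31, 2**31 - 1

def can_fold_mul_mod(m_min, m_max, n_min, n_max):
    # if n's range contains zero, eager mode could raise ZeroDivisionError
    if n_min <= 0 <= n_max:
        return False
    # extremes of m*n over the two intervals come from the corner products
    # (exact here, since Python ints don't overflow)
    corners = (m_min * n_min, m_min * n_max, m_max * n_min, m_max * n_max)
    return INT32_MIN <= min(corners) and max(corners) <= INT32_MAX

For loop-index math where both ranges are known (e.g. 0 <= m < 100 and n a positive extent), the check passes; for a free variable like the x in the Compute example above, it cannot.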
I'm not sure we could use LLVM (and C/C++ in general) as guidance here, since a discrepancy from pytorch eager would probably be considered a bug.
Here is the behavior we need to replicate, I think:
>>> x = torch.ones((1,1), dtype=torch.int32) * 7
>>> y = torch.zeros((1,1), dtype=torch.int32)
>>> x * y / y
tensor([[nan]])
>>> x * y % y
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: ZeroDivisionError
>>> x * y // y
/home/mvz/pytorch/torch/tensor.py:565: UserWarning: floor_divide is deprecated, and will be removed in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values.
To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at ../aten/src/ATen/native/BinaryOps.cpp:341.)
return torch.floor_divide(self, other)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/mvz/pytorch/torch/tensor.py", line 29, in wrapped
return f(*args, **kwargs)
File "/home/mvz/pytorch/torch/tensor.py", line 565, in __floordiv__
return torch.floor_divide(self, other)
RuntimeError: ZeroDivisionError
Heh, so I banned aten::fmod on integers from fusion groups for exactly this reason. It's really troublesome to preserve exceptions from edge cases like this in generated code.
For m * n % n as an indexing expression it's fine for us to simplify it to 0, but if the user has typed that in python, we should honor it.
🐛 Bug
The transformation in IRSimplifier for integers:
m*n % m -> 0
is not safe because m is not guaranteed to be non-zero. We need to forbid this optimization.
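A minimal eager-mode reproduction of the unsafe case (assuming the same CPU integer-remainder behavior shown in the comments above); folding the expression to 0 would silently swallow this error:
>>> import torch
>>> m = torch.zeros((1, 1), dtype=torch.int32)
>>> n = torch.full((1, 1), 7, dtype=torch.int32)
>>> m * n % m
RuntimeError: ZeroDivisionError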