pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License

[inductor] `maximum` op does not respect `nans` #1572

Closed: anijain2305 closed this issue 1 year ago

anijain2305 commented 2 years ago

Repro

import torch
from torch.fx.experimental.proxy_tensor import make_fx

# torch version: 1.14.0a0+git65b4080
# torch cuda version: 11.6
# torch git version: 65b408074f4ecc99faf5720ea5b3570a483ec9f4

# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Thu_Feb_10_18:23:41_PST_2022
# Cuda compilation tools, release 11.6, V11.6.112
# Build cuda_11.6.r11.6/compiler.30978841_0

# GPU Hardware Info:
# NVIDIA A100-SXM4-40GB : 8

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, add_283, lift_fresh_copy_180):
        maximum_90 = torch.ops.aten.maximum.default(add_283, lift_fresh_copy_180)
        return (maximum_90, )

x = torch.tensor([-1.0, 0, 1.0, -2.0], device="cuda")
# sqrt of a negative number is NaN, so x == [nan, 0., 1., nan]
x = torch.sqrt(x)
add = torch.randn(4, device="cuda")

args = [x, add]
mod = make_fx(Repro().to(device="cuda"))(*args)

from torchinductor.compile_fx import compile_fx_inner
from torchdynamo.debug_utils import same_two_models

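compiled = compile_fx_inner(mod, args)
ref = mod(*args)       # eager: torch.maximum returns NaN wherever x is NaN
res = compiled(*args)  # inductor: per this issue, the NaNs from x are not propagated
assert same_two_models(mod, compiled, args, only_fwd=True), "Accuracy failed"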
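Eager `torch.maximum` returns the NaN whenever either compared element is NaN. A plausible reading of this bug (an assumption, not confirmed in this thread) is that the compiled kernel lowers `maximum` to a plain compare-and-select, and every comparison involving NaN is False. The minimal pure-Python sketch below contrasts the two semantics on inputs shaped like the repro. The function names are hypothetical, and this is not inductor's actual kernel code.

```python
import math


def naive_maximum(a: float, b: float) -> float:
    # Compare-and-select: any comparison with NaN is False, so a NaN in `a`
    # is silently dropped and `b` is returned instead.
    return a if a > b else b


def nan_propagating_maximum(a: float, b: float) -> float:
    # `a != a` is True only when `a` is NaN, so selecting `a` in that case
    # propagates it. If only `b` is NaN, `a > b` is False and `b` (NaN) is
    # returned, so a NaN in either operand wins, as with torch.maximum.
    return a if (a != a or a > b) else b


# What torch.sqrt produces for [-1, 0, 1, -2] in the repro.
x = [math.nan, 0.0, 1.0, math.nan]
y = [0.5, -0.3, 2.0, -1.0]

naive = [naive_maximum(a, b) for a, b in zip(x, y)]
fixed = [nan_propagating_maximum(a, b) for a, b in zip(x, y)]
print(naive)  # [0.5, 0.0, 2.0, -1.0]  (NaNs from x are lost)
print(fixed)  # [nan, 0.0, 2.0, nan]
```

Note that the naive version is also asymmetric: `naive_maximum(1.0, math.nan)` does return NaN, so the bug only shows up when the NaN sits in the first operand, which is where the repro places it.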

cc @desertfire @SherlockNoMad

Fixing this helps with the fbnet and mobilenetv3_100 models.

ngimel commented 2 years ago

Duplicate of pytorch/pytorch#93784