pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Dynamo reordering positional arguments #133116

Closed: jakeharmon8 closed this issue 2 months ago

jakeharmon8 commented 2 months ago

🐛 Describe the bug

Dynamo in PyTorch 2.4.0 is tracing simple functions differently from PyTorch 2.3.1:

import torch

def debug_backend(gm, inputs):
  # Print the FX graph Dynamo produced, then run it unchanged.
  print(gm)
  return gm

def fn(x, y):
  x = torch.matmul(x, y)
  return x

torch._dynamo.optimize(debug_backend)(fn)(torch.ones(2, 4), torch.ones(4, 8))

Output in 2.3.1:

def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
    l_x_ = L_x_
    l_y_ = L_y_
    x = torch.matmul(l_x_, l_y_);  l_x_ = l_y_ = None
    return (x,)

Output in 2.4.0:

def forward(self, L_y_ : torch.Tensor, L_x_ : torch.Tensor):
    l_y_ = L_y_
    l_x_ = L_x_
    x = torch.matmul(l_x_, l_y_);  l_x_ = l_y_ = None
    return (x,)

The reordering of the arguments has made 2.4.0 not backwards compatible with our custom backend. Is this intentional? And if so, how do other backends (e.g. inductor) handle positional arguments being reordered?
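For illustration only (this sketch is not part of the original report): the kind of positional assumption that this reordering would break, assuming Dynamo passes example_inputs to the backend in the same order as the graph's placeholders. The backend name and shapes below are hypothetical.

import torch

def fragile_backend(gm, example_inputs):
    # Hypothetical backend that assumes example_inputs[0] is always `x`
    # (shape (2, 4)) and example_inputs[1] is always `y` (shape (4, 8)).
    # With the 2.3.1 placeholder order (L_x_, L_y_) that holds; with the
    # 2.4.0 order (L_y_, L_x_) the assumption silently flips.
    print("first input shape:", tuple(example_inputs[0].shape))
    return gm

def fn(x, y):
    return torch.matmul(x, y)

torch._dynamo.optimize(fragile_backend)(fn)(torch.ones(2, 4), torch.ones(4, 8))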

Error logs

n/a

Minified repro

import torch

if __name__ == "__main__":
  def debug_backend(gm, inputs):
    print(gm)
    return gm

  def fn(x, y):
    x = torch.matmul(x, y)
    return x

  torch._dynamo.optimize(debug_backend)(fn)(torch.ones(2, 4), torch.ones(4, 8))

Versions

Collecting environment information...

Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7B13
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 0
BogoMIPS: 4899.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip vaes vpclmulqdq rdpid fsrm
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 2 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 32 MiB (64 instances)
L3 cache: 256 MiB (8 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] torch==2.4.0
[pip3] triton==3.0.0
[conda] Could not collect

cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

shunting314 commented 2 months ago

Related to nn module inlining? @anijain2305

anijain2305 commented 2 months ago

This is not related to inlining.

I am not sure what caused this change, but in general the Dynamo FX graph does not provide a BC (backwards-compatibility) guarantee on the ordering of its inputs.

Backends like inductor do not rely on the ordering of the FX graph's inputs being maintained. As far as the matmul goes, it still receives the right set of inputs.
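A minimal sketch (not from this thread) of one way a custom backend can stay independent of placeholder ordering, assuming, as the graphs above suggest, that example_inputs arrive in the same order as the graph's placeholders and that placeholder targets carry the source-level names such as "L_x_" and "L_y_". The backend keys on those names rather than on positions.

import torch

def order_agnostic_backend(gm, example_inputs):
    # Pair each placeholder with the example input delivered at the same
    # position; the names (e.g. "L_x_", "L_y_") survive reordering even
    # though the positions do not.
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    by_name = {node.target: inp for node, inp in zip(placeholders, example_inputs)}
    print(sorted(by_name.keys()))
    return gm  # run the traced graph unchanged

def fn(x, y):
    return torch.matmul(x, y)

out = torch._dynamo.optimize(order_agnostic_backend)(fn)(torch.ones(2, 4), torch.ones(4, 8))
assert out.shape == (2, 8)  # the result is correct in either placeholder order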

cc @mlazos, as I might be slow in responding next week.