pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[export] aten._weight_norm_interface doesn't have decomposition #112086

Closed qihqi closed 1 year ago

qihqi commented 1 year ago

🐛 Describe the bug

The op aten._weight_norm_interface is not a core ATen op, and it doesn't have a decomposition.

Repro:

import torch

args = (torch.randn(768, 48, 128), torch.randn(1, 1, 128))

def func(x, y):
    return torch.ops.aten._weight_norm_interface(x, y, 2)

func(args[0], args[1])
exp = torch.export.export(func, args)
exp = exp.run_decompositions()
print(exp.graph_module.code)

Output:

def forward(self, arg0_1, arg1_1):
    _weight_norm_interface = torch.ops.aten._weight_norm_interface.default(arg0_1, arg1_1, 2);  arg0_1 = arg1_1 = None
    getitem = _weight_norm_interface[0]
    getitem_1 = _weight_norm_interface[1];  _weight_norm_interface = None
    return (getitem, getitem_1)

Expected output: I expect aten._weight_norm_interface to be decomposed into other ops that are in the core ATen set.
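For context, weight normalization reparameterizes a tensor as w = g * v / ||v||, where the norm is taken over every dimension except `dim`. A decomposition along those lines could look like the sketch below (an illustration, not an official decomposition; the helper name `weight_norm_decomp` is made up here, and it assumes the second output of `_weight_norm_interface` is the norm saved for backward):

```python
import torch

def weight_norm_decomp(v, g, dim=0):
    # Sketch of a possible decomposition (hypothetical helper, not the
    # official one). 2-norm over every dimension except `dim`,
    # keepdim=True so the norm broadcasts against v the way g does.
    reduce_dims = tuple(i for i in range(v.dim()) if i != dim)
    norm = v.norm(2, reduce_dims, keepdim=True)
    return v * (g / norm), norm

# Compare against the ATen op on the shapes from the repro above.
v = torch.randn(768, 48, 128)
g = torch.randn(1, 1, 128)
w, n = weight_norm_decomp(v, g, dim=2)
w_ref, n_ref = torch.ops.aten._weight_norm_interface(v, g, 2)
torch.testing.assert_close(w, w_ref)
torch.testing.assert_close(n.flatten(), n_ref.flatten())
```

All ops used here (`norm`, `div`, `mul`) are ordinary ATen ops, so a decomposition in this style should export to a core-ATen-only graph.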

Versions

NUMA node1 CPU(s): 28-55,84-111
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.4.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.25.0
[pip3] onnx==1.13.1
[pip3] onnxruntime==1.15.1
[pip3] torch==2.2.0a0+gite268139
[pip3] torch-xla==2.2.0
[pip3] torchaudio==2.0.2
[pip3] torchdata==0.7.0
[pip3] torchtext==0.16.0
[pip3] torchvision==0.16.0a0+463cdea
[pip3] triton==2.1.0
[conda] blas 1.0 mkl
[conda] mkl 2023.1.0 h6d00ec8_46342
[conda] mkl-include 2023.2.0 pypi_0 pypi
[conda] mkl-service 2.4.0 py310h5eee18b_1
[conda] mkl_fft 1.3.6 py310h1128e8f_1
[conda] mkl_random 1.2.2 py310h1128e8f_1
[conda] numpy 1.25.2 pypi_0 pypi
[conda] numpy-base 1.25.0 py310hb5e798b_0
[conda] torch 2.2.0a0+gite268139 dev_0
[conda] torch-xla 2.2.0 dev_0
[conda] torchaudio 2.0.2 pypi_0 pypi
[conda] torchdata 0.7.0 pypi_0 pypi
[conda] torchtext 0.16.0 pypi_0 pypi
[conda] torchvision 0.16.0a0+463cdea pypi_0 pypi
[conda] triton 2.1.0 pypi_0 pypi

cc @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo

qihqi commented 1 year ago

I'd be happy to contribute a decomposition, but can anyone tell me the math formula for it?

I found this: https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html

but aten._weight_norm_interface returns two tensors; what is the second one?
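Empirically (checked on CPU; an observation, not a documented contract), the second output looks like the 2-norm of `v` taken over every dimension except `dim`, i.e. the denominator that weight norm saves for the backward pass:

```python
import torch

v = torch.randn(768, 48, 128)
g = torch.randn(1, 1, 128)

# Second output of the op vs. an explicit norm over all dims except dim=2.
# Observation only: this matches on CPU but is not a documented contract.
w, norms = torch.ops.aten._weight_norm_interface(v, g, 2)
expected_norm = v.norm(2, dim=(0, 1), keepdim=True)

torch.testing.assert_close(norms.flatten(), expected_norm.flatten())
torch.testing.assert_close(w, v * (g / expected_norm))
```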

tugsbayasgalan commented 1 year ago

cc: @SS-JIA @angelayi