apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] [Relax] Cannot import mobilenet_v3 #17068

Closed mshr-h closed 3 weeks ago

mshr-h commented 1 month ago

Cannot import mobilenet_v3 because Hardswish and Hardsigmoid are not supported by the Relax PyTorch frontend. I'll try to fix it.

TODOs

Expected behavior

mobilenet_v3_small and mobilenet_v3_large can be imported with from_fx.

Actual behavior

I got the error message below when I ran the repro.

$ python compile_mobilenet_v3.py 
Traceback (most recent call last):
  File "/home/ubuntu/data/sandbox/tvm_/relax_/mobilenet_v3/compile_mobilenet_v3.py", line 34, in <module>
    main()
  File "/home/ubuntu/data/sandbox/tvm_/relax_/mobilenet_v3/compile_mobilenet_v3.py", line 21, in main
    mod = from_fx(graph_model, [(inp.shape, "float32")])
  File "/home/ubuntu/data/sandbox/.dep/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1698, in from_fx
    return TorchFXImporter().from_fx(
  File "/home/ubuntu/data/sandbox/.dep/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1570, in from_fx
    type(module) in self.convert_map
AssertionError: Unsupported module type <class 'torch.nn.modules.activation.Hardswish'>
[20:07:07] /home/ubuntu/data/sandbox/.dep/tvm/src/relax/ir/block_builder.cc:66: Warning: BlockBuilder destroyed with remaining blocks!
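
The assertion in the traceback (`type(module) in self.convert_map`) suggests that the importer looks up each module by its type in a dictionary of converters. Below is a minimal pure-Python sketch of that pattern, not TVM's actual code: `ToyImporter`, `unsupported_types`, and the stand-in layer classes are hypothetical names made up for illustration. It also shows how one could collect every missing type in a single pass instead of failing on the first one.

```python
# Toy stand-ins for torch.nn module classes (hypothetical, illustration only).
class Conv2d: ...
class ReLU: ...
class Hardswish: ...
class Hardsigmoid: ...


class ToyImporter:
    """Hypothetical importer that dispatches on the module's type."""

    def __init__(self):
        # Maps a module class to the function that converts it.
        self.convert_map = {
            Conv2d: lambda m: "conv2d",
            ReLU: lambda m: "relu",
        }

    def convert(self, module):
        # Same shape of check as the assertion in the traceback above.
        assert type(module) in self.convert_map, (
            f"Unsupported module type {type(module)}"
        )
        return self.convert_map[type(module)](module)

    def unsupported_types(self, modules):
        # Collect every missing type in one pass instead of stopping at the first.
        return sorted(
            {type(m).__name__ for m in modules if type(m) not in self.convert_map}
        )


importer = ToyImporter()
layers = [Conv2d(), Hardswish(), ReLU(), Hardsigmoid()]
missing = importer.unsupported_types(layers)
print(missing)
```

In this toy setup, adding support for a new layer amounts to registering one more entry in `convert_map`. Whether the real fix is that small depends on the frontend's internals.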

Environment

OS: Ubuntu 22.04 LTS on WSL2
TVM: 0e622e140c2df8c5ab88e27ee4e90254cddb80ce
PyTorch: 2.3.0
Torchvision: 0.18.0

Steps to reproduce

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx
import torch
import torchvision

def main():
  model_name = "mobilenet_v3_small"  # mobilenet_v3_small or mobilenet_v3_large
  inp = torch.rand(8, 3, 224, 224)

  weights = torchvision.models.get_model_weights(model_name).DEFAULT
  model_pth = torchvision.models.get_model(model_name, weights=weights).eval()

  # PyTorch
  output_pth = model_pth(inp)

  # TVM
  graph_model = torch.fx.symbolic_trace(model_pth)
  with torch.no_grad():
    mod = from_fx(graph_model, [(inp.shape, "float32")])

  target = tvm.target.Target("llvm", host="llvm")
  mod = relax.transform.DecomposeOpsForInference()(mod)
  mod = relax.transform.LegalizeOps()(mod)
  ex = relax.build(mod, target)
  vm = relax.VirtualMachine(ex, tvm.cpu())
  output_tvm = torch.tensor(vm["main"](tvm.nd.array(inp.detach().numpy())).numpy())

  torch.testing.assert_close(output_pth, output_tvm, rtol=1e-5, atol=1e-5)

if __name__ == "__main__":
  main()

Triage


cc @junrushao

yongwww commented 1 month ago

The error message shows that the Hardswish operation is not supported in the FX converter. Would you mind sending a PR to add it in https://github.com/apache/tvm/blob/main/python/tvm/relax/frontend/torch/fx_translator.py?

mshr-h commented 1 month ago

Thank you for your comment! Yes, I'm implementing Hardswish support in the torch frontend (fx_translator.py). The initial implementation works fine. After cleaning up my code, I'll send the PR.
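
For reference, PyTorch defines Hardsigmoid as ReLU6(x + 3) / 6 and Hardswish as x * ReLU6(x + 3) / 6. The scalar sketch below spells that out with only add, clip, divide, and multiply, which is one way a converter could express both activations. Whether the eventual PR decomposes them this way is an assumption, and the helper names are illustrative only.

```python
def clip(x, lo, hi):
    # Clamp x into the closed interval [lo, hi].
    return max(lo, min(hi, x))


def hardsigmoid(x):
    # ReLU6(x + 3) / 6, with ReLU6 written as a clip to [0, 6].
    return clip(x + 3.0, 0.0, 6.0) / 6.0


def hardswish(x):
    # x * ReLU6(x + 3) / 6, i.e. x scaled by its hardsigmoid.
    return x * hardsigmoid(x)


for x in (-4.0, -1.0, 0.0, 1.0, 4.0):
    print(x, hardsigmoid(x), hardswish(x))
```

Both functions saturate outside [-3, 3]: Hardsigmoid is 0 below -3 and 1 above 3, so Hardswish is 0 below -3 and the identity above 3.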