apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[Bug] Many imported Relay models fail to parse when saved in the text format #9572

Closed: slyubomirsky closed this issue 2 years ago

slyubomirsky commented 2 years ago

The naming conventions of many other deep learning frameworks do not play nicely with the Relay text format. For example, many PyTorch models include dots in variable names, which the Relay text parser treats as index operators. Curiously, the imported programs work at first (when used directly as the imported AST objects), but they stop working once they are written to the text format and parsed back in (at which point they fail to type check).

For example, with the PyTorch importer (used in the unit tests):

import torch
import tvm
from tvm import relay

model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
model = model.eval()
src = torch.rand((10, 32, 256))
tgt = torch.rand((10, 32, 256))
trace = torch.jit.trace(model, [src, tgt])
trace = trace.float().eval().cpu()
input_shapes = [("src", (10, 32, 256)), ("tgt", (10, 32, 256))]
mod, params = relay.frontend.from_pytorch(trace, input_shapes, {})

# writes successfully
with open("transformer.relay", "w") as fp:
    fp.write(mod.astext())

# attempt to parse the text back into a module (this is the step that fails)
with open("transformer.relay", "r") as fin:
    parsed = tvm.parser.fromtext(fin.read())

The model fails to parse because many of its parameters have dots in their names, e.g., %decoder.layers.0.self_attn.in_proj_weight. This yields the error "expected a local variable found '.'".
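To illustrate why a dotted name cannot survive the round trip, the sketch below uses a toy tokenizer. It assumes, for illustration only, that a local variable is `%` followed by letters, digits, and underscores, and that `.` is lexed as its own token (as in tuple projection such as `%t.0`). `toy_tokens` is a hypothetical helper, not TVM's actual tokenizer.

```python
import re
from typing import List

# Toy model only (assumption, not TVM's tokenizer): a local variable is '%'
# followed by letters/digits/underscores, and '.' is always a separate token.
_TOKEN = re.compile(r"%[A-Za-z0-9_]+|\.|[A-Za-z0-9_]+")


def toy_tokens(text: str) -> List[str]:
    """Split text into toy tokens to show where a dotted name breaks apart."""
    return _TOKEN.findall(text)


if __name__ == "__main__":
    # The dotted parameter name falls apart into several tokens...
    print(toy_tokens("%decoder.layers.0.self_attn.in_proj_weight"))
    # ...while an underscore-only name stays a single local-variable token.
    print(toy_tokens("%decoder_layers_0_self_attn_in_proj_weight"))
```

Under this assumption, the declared name splits into `%decoder` followed by a chain of projection-like pieces, so the parser no longer sees one variable.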

I see two approaches for addressing this:

  1. Sanitize variable names in the text format pretty-printer
  2. Sanitize variable names in the importers
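Either approach needs a name-sanitizing function. A minimal sketch, assuming (conservatively, not taken from the Relay grammar) that names made only of ASCII letters, digits, and underscores, and not starting with a digit, are safe to print and parse. `sanitize_name` and `sanitize_all` are hypothetical helpers. Because different original names can collapse to the same sanitized name, the mapping appends a numeric suffix to keep it one-to-one.

```python
import re
from typing import Dict, Iterable

# Assumption: only [A-Za-z0-9_] is safe, and a leading digit is avoided.
_INVALID = re.compile(r"[^A-Za-z0-9_]")


def sanitize_name(name: str) -> str:
    """Replace every disallowed character with '_' (hypothetical helper)."""
    clean = _INVALID.sub("_", name)
    if not clean or clean[0].isdigit():
        clean = "_" + clean
    return clean


def sanitize_all(names: Iterable[str]) -> Dict[str, str]:
    """Map each (distinct) original name to a sanitized, unique name.

    "a.b" and "a_b" both sanitize to "a_b", so later collisions get a
    numeric suffix ("a_b_1", "a_b_2", ...) to keep the mapping one-to-one.
    """
    mapping: Dict[str, str] = {}
    used = set()
    for name in names:
        candidate = base = sanitize_name(name)
        suffix = 1
        while candidate in used:
            candidate = f"{base}_{suffix}"
            suffix += 1
        used.add(candidate)
        mapping[name] = candidate
    return mapping


if __name__ == "__main__":
    print(sanitize_all(["decoder.layers.0.self_attn.in_proj_weight", "src"]))
```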

In principle, the first approach is attractive because it would not require modifying every importer. However, importers typically return both a Relay module and a parameter dictionary, and a printer-side sanitizer would also have to rename the matching keys in the parameter dictionary, which the printer never sees. So the importers themselves may be the better place to sanitize names.
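The parameter-dictionary point can be sketched as follows: the importer renames the dictionary keys and returns the old-to-new map, which it would also apply to the corresponding Relay variables so that module and dictionary stay in sync. `sanitize_params` is hypothetical, and its character rule (keep only ASCII letters, digits, and underscores) is a conservative assumption rather than the Relay grammar. It raises on collisions instead of silently merging two parameters.

```python
import re
from typing import Dict, Tuple, TypeVar

V = TypeVar("V")  # parameter values, e.g. arrays in a real importer


def _sanitize(name: str) -> str:
    # Assumption: replace anything outside [A-Za-z0-9_] with '_'.
    return re.sub(r"[^A-Za-z0-9_]", "_", name)


def sanitize_params(params: Dict[str, V]) -> Tuple[Dict[str, V], Dict[str, str]]:
    """Rename dictionary keys; return (renamed_params, old_to_new_map).

    The importer would apply the same map to the Relay variables, so every
    renamed key still names an input of the module.
    """
    renamed: Dict[str, V] = {}
    mapping: Dict[str, str] = {}
    for old, value in params.items():
        new = _sanitize(old)
        if new in renamed:
            raise ValueError(f"{old!r} collides with another parameter as {new!r}")
        renamed[new] = value
        mapping[old] = new
    return renamed, mapping
```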

I also think it would be a good idea to include text-format round-tripping in the importer tests to ensure this does not become an issue again in the future.

slyubomirsky commented 2 years ago

I checked, and the PyTorch importer has a "parser-friendly" flag that replaces only dots, which is reasonable (maybe it should be on by default). I will check whether this problem exists in the other importers.
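To check other importers, one could audit the names they produce against a dots-only rename like the one described above (in `relay.frontend.from_pytorch` this option is the `use_parser_friendly_name` argument). The helpers below are hypothetical, the "safe name" rule is a conservative assumption rather than the Relay grammar, and `dense/kernel:0` is only an illustrative name of the kind another framework might produce.

```python
import re
from typing import Iterable, List

# Assumption: a name is safe if it is letters/digits/underscores and does
# not start with a digit.
_SAFE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def replace_dots(name: str) -> str:
    """Mimic a dots-only "parser-friendly" rename: '.' becomes '_'."""
    return name.replace(".", "_")


def still_unsafe(names: Iterable[str]) -> List[str]:
    """Return the names that remain unsafe after the dots-only rename."""
    return [n for n in names if not _SAFE.fullmatch(replace_dots(n))]


if __name__ == "__main__":
    print(still_unsafe(["decoder.layers.0.self_attn.in_proj_weight", "dense/kernel:0"]))
```

Any name this reports would need a broader sanitizer than dot replacement.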