Mrmoore98 / VectorMapNet_code

This is the official code base of VectorMapNet (ICML 2023)
https://tsinghua-mars-lab.github.io/vectormapnet/
GNU General Public License v3.0

Export model to ONNX #22

Open Fabian-Ket opened 1 year ago

Fabian-Ket commented 1 year ago

Hello, I tried to export VectorMapNet, loaded with the provided checkpoint, to ONNX, but I can't figure out how to pass the input to the export call:

# the model is initialized as in tools/test.py
model = build_model(cfg.model, test_cfg=cfg.get('test_cfg'))
...

mm = MMDataParallel(model, device_ids=[0])

dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=1,
        dist=False,
        shuffle=False)

for i, data in enumerate(data_loader):
    torch.onnx.export(mm.module, args=data, f='VectorMapNet.onnx')
    break

But this fails with:

RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs. Dictionaries and strings are also accepted, but their usage is not recommended. Here, received an input of unsupported type: DataContainer

I don't know exactly what the input and output tensors of the model are.
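The error message indicates that the batch from the data loader still contains DataContainer objects, which the tracer does not accept. A minimal sketch of one way to strip such wrappers is shown here. It assumes the wrapper exposes its payload through a `.data` attribute, as mmcv's DataContainer does. `DataContainerStandIn`, `unwrap`, and the sample keys are hypothetical names used only so the example runs without mmcv installed.

```python
class DataContainerStandIn:
    """Hypothetical stand-in for a DataContainer: wraps a payload in `.data`."""

    def __init__(self, data):
        self.data = data


def unwrap(obj, container_types=(DataContainerStandIn,)):
    """Recursively replace wrapper objects with their `.data` payload.

    Dicts, lists and tuples are rebuilt with the same type; every other
    object is returned unchanged.
    """
    if isinstance(obj, container_types):
        return unwrap(obj.data, container_types)
    if isinstance(obj, dict):
        return {key: unwrap(value, container_types) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(unwrap(value, container_types) for value in obj)
    return obj


# Illustrative batch; the keys and values are assumptions, not the real
# VectorMapNet batch layout.
batch = {
    "img": DataContainerStandIn([[0.1, 0.2]]),
    "img_metas": [DataContainerStandIn([{"token": "abc"}])],
}
plain = unwrap(batch)
print(plain)
```

With mmcv installed, the same function could presumably be called as `unwrap(data, container_types=(DataContainer,))` with `DataContainer` imported from `mmcv.parallel`. Unwrapping alone is unlikely to be enough for `torch.onnx.export`, since any non-tensor metadata left in the batch would probably still have to be fixed inside a small wrapper module whose forward takes only tensors.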

I also tried to use:

from torchviz import make_dot

for i, data in enumerate(data_loader):
    with torch.no_grad():
        yhat = mm(return_loss=False, rescale=True, **data)
        break

make_dot(yhat, params=dict(list(mm.module.named_parameters()))).render("VectorMapNet", format="png")

to plot the model and get better insight into it, but this fails with TypeError: unhashable type: 'list'. I think this is because make_dot() can't handle the post-processed yhat prediction returned by the mm() call.
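The TypeError is consistent with yhat being a Python list rather than a tensor. Before plotting or exporting, it can help to print what a nested input or output actually contains. The following is a minimal sketch of such an inspection helper. `describe`, `FakeTensor`, and the sample keys are hypothetical and only stand in for real model outputs so the example runs without torch. Note also that, as far as I know, results computed under `torch.no_grad()` carry no autograd graph, so make_dot would have nothing to draw even if it were given a tensor from that call.

```python
def describe(obj, path="out"):
    """Return one line per leaf: its path, its type name, and its shape if any."""
    if isinstance(obj, dict):
        lines = []
        for key, value in obj.items():
            lines.extend(describe(value, f"{path}[{key!r}]"))
        return lines
    if isinstance(obj, (list, tuple)):
        lines = []
        for index, value in enumerate(obj):
            lines.extend(describe(value, f"{path}[{index}]"))
        return lines
    # Leaf object: report the shape only when the object exposes one.
    shape = getattr(obj, "shape", None)
    suffix = f" shape={tuple(shape)}" if shape is not None else ""
    return [f"{path}: {type(obj).__name__}{suffix}"]


class FakeTensor:
    """Hypothetical stand-in for a tensor: only carries a `.shape`."""

    def __init__(self, shape):
        self.shape = shape


# Illustrative output; the real structure of yhat is not known here.
sample = [{"lines": FakeTensor((3, 2)), "token": "abc"}]
report = describe(sample)
print("\n".join(report))
```

Calling `describe(yhat)` on the real prediction, or on a batch after its wrappers have been removed, should list every tensor together with its shape, which is the information needed to decide what the exported graph's inputs and outputs could be.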

So how can VectorMapNet be exported to ONNX? I'm really new to PyTorch.