kaiyuyue / nxtp

Object Recognition as Next Token Prediction (CVPR 2024 Highlight)
https://arxiv.org/abs/2312.02142

Export `LangClassifier` to onnx #2

Closed: evdcush closed this issue 7 months ago

evdcush commented 8 months ago

Thanks for providing code and checkpoints for your method! 🙇🏻‍♂️

The method is interesting, but the paper lacks any basic diagram or graphic explaining the model architecture at a high level.

To get a coarse understanding of the model architecture, I thought it would help to export the model to ONNX format and visualize the architecture with Netron.

Getting `nxtp/src/infer.py` to run was a fairly involved process, but after a good amount of effort I got it running and attempted to export the model with `torch.onnx.export`.

So, I ran:

```python
torch.onnx.export(model, img, 'nxtp_g3m_model.onnx')
```

However, it looks like `LangClassifier` has a dummy `forward`, which I haven't seen before in `nn.Module` implementations. Maybe a regular Python class would have been more suitable, but in any case, I was still hoping to export the outermost meta-architecture of your method.

Could you provide an ONNX file for nxtp, or document how to export the model to ONNX?

kaiyuyue commented 8 months ago

Hi @evdcush

I saw your issue. `LangClassifier` has a dummy `forward` because its inference procedure is long (encoding images, looping next-token prediction, and sampling labels), so I thought putting these steps directly in the main forward flow would make the code clearer.
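To make those three steps concrete, here is a minimal pure-Python sketch of such an inference loop. Everything in it is hypothetical: `encode_image` and `next_token` stand in for the ViT encoder and the decoder's next-token head, `sep_id` and `eos_id` are assumed special tokens, and greedy decoding replaces whatever sampling the repo actually uses.

```python
from typing import Callable, List


def greedy_decode_labels(
    encode_image: Callable[[object], List[float]],
    next_token: Callable[[List[float], List[int]], int],
    image: object,
    eos_id: int,
    sep_id: int,
    max_new_tokens: int = 32,
) -> List[List[int]]:
    """Hypothetical sketch: encode an image, loop next-token prediction, split into labels."""
    prefix = encode_image(image)          # step 1: image features used as the decoder prefix
    tokens: List[int] = []
    for _ in range(max_new_tokens):       # step 2: autoregressive next-token loop
        tok = next_token(prefix, tokens)
        if tok == eos_id:                 # stop once the assumed end token appears
            break
        tokens.append(tok)
    # step 3: split the generated sequence into per-label token lists at sep_id
    labels: List[List[int]] = [[]]
    for tok in tokens:
        if tok == sep_id:
            labels.append([])
        else:
            labels[-1].append(tok)
    return [lab for lab in labels if lab]  # drop empty groups
```

Tracing-based export (the classic `torch.onnx.export` path) records only the path taken for the example input, so a data-dependent loop like this is usually kept outside the exported graph.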

The whole model has two parts: an image encoder $f$ and a language decoder $h$.

In the repo, $f$ is a vanilla ViT from CLIP, and $h$ is a standard LLaMA-1/-2 with fewer transformer blocks, obtained by truncating/dropping the intermediate blocks.
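A decoder with fewer blocks can be sketched as a plain list operation. This is an assumption about the scheme, not the repo's code: the hypothetical helper `drop_intermediate_blocks` keeps the first `keep_first` and the last `keep_last` blocks of a deep stack and discards the ones in between; exactly which blocks nxtp keeps is not stated in this thread.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def drop_intermediate_blocks(blocks: Sequence[T], keep_first: int, keep_last: int = 0) -> List[T]:
    """Hypothetical helper: keep the first/last blocks and drop the intermediate ones."""
    if keep_first < 0 or keep_last < 0:
        raise ValueError("keep counts must be non-negative")
    if keep_first + keep_last >= len(blocks):
        return list(blocks)  # nothing to drop
    # blocks[len(blocks) - 0:] is empty, so keep_last=0 keeps only the leading blocks
    return list(blocks[:keep_first]) + list(blocks[len(blocks) - keep_last:])
```

With PyTorch modules, the returned list could be wrapped in `torch.nn.ModuleList` to form the shallower decoder.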

So for ONNX export, the easiest approach is probably to wrap $f$ and $h$ in a single `nn.Module` class, with its own model definition and `forward` function, and then export that. Several repos have already implemented ONNX export for ViT and LLaMA-2.
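A minimal sketch of that wrapping, assuming PyTorch. The class `EncoderDecoderForExport`, the projection `proj`, the helper `export_onnx`, and all tensor shapes are hypothetical; the decoder is assumed to accept input embeddings and return per-position logits. The exported graph covers one decoding step (image plus already-generated token embeddings to next-token logits), while the loop stays in Python.

```python
import torch
from torch import nn


class EncoderDecoderForExport(nn.Module):
    """Hypothetical wrapper chaining an image encoder f and a language decoder h."""

    def __init__(self, f: nn.Module, proj: nn.Module, h: nn.Module):
        super().__init__()
        self.f = f        # image encoder (a CLIP ViT in the repo)
        self.proj = proj  # assumed projection from encoder width to decoder width
        self.h = h        # decoder mapping (B, L, D) embeddings to (B, L, V) logits

    def forward(self, image: torch.Tensor, token_embeds: torch.Tensor) -> torch.Tensor:
        prefix = self.proj(self.f(image))               # (B, N_img, D)
        seq = torch.cat([prefix, token_embeds], dim=1)  # (B, N_img + T, D)
        logits = self.h(seq)                            # (B, N_img + T, V)
        return logits[:, -1, :]                         # next-token logits, (B, V)


def export_onnx(model: nn.Module, image: torch.Tensor, token_embeds: torch.Tensor,
                path: str = "nxtp_sketch.onnx") -> None:
    # Export one static decoding step; the autoregressive loop stays outside the graph.
    model.eval()
    torch.onnx.export(
        model,
        (image, token_embeds),
        path,
        input_names=["image", "token_embeds"],
        output_names=["next_token_logits"],
    )
```

For a quick check without the real weights, toy layers can stand in for $f$ and $h$, e.g. `EncoderDecoderForExport(nn.Linear(8, 16), nn.Linear(16, 32), nn.Linear(32, 100))` with inputs of shape `(1, 4, 8)` and `(1, 3, 32)`.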

Our language decoder uses a non-causal attention mask (not the single lower-triangular mask of a causal one). This design might cause issues in ONNX export, but I suggest just ignoring it and using the original causal mask in the ONNX forward pass for inference. From a product perspective, performance won't drop much (see Table 5 in the paper).
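To make the masking difference concrete, here is a pure-Python sketch of additive attention masks (0 means "may attend", `-inf` means "blocked"). `causal_mask` is the standard lower-triangular mask suggested for ONNX inference. `label_grouped_mask` is a hypothetical illustration of a non-causal variant, in which every token sees the image prefix and label tokens attend only within their own label; it is not necessarily the exact mask used in nxtp.

```python
from typing import List, Sequence

NEG_INF = float("-inf")


def causal_mask(n: int) -> List[List[float]]:
    """Additive causal mask: position i may attend to positions j <= i."""
    return [[0.0 if j <= i else NEG_INF for j in range(n)] for i in range(n)]


def label_grouped_mask(prefix_len: int, label_lens: Sequence[int]) -> List[List[float]]:
    """Hypothetical non-causal mask: tokens see the prefix, labels do not see each other."""
    groups = [-1] * prefix_len           # group id -1 marks the image/prompt prefix
    for k, length in enumerate(label_lens):
        groups += [k] * length           # group id k marks tokens of the k-th label
    n = len(groups)
    mask = causal_mask(n)                # start from the lower triangle
    for i in range(n):
        for j in range(i + 1):
            if groups[j] != -1 and groups[j] != groups[i]:
                mask[i][j] = NEG_INF     # block attention across different labels
    return mask
```

Only `causal_mask` is needed for the simplified ONNX forward pass; every row keeps its diagonal at 0, so the softmax over each row stays well defined.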

So yeah, I can try to export an ONNX file on my side, and I'll keep you updated if I make progress. Thanks!

kaiyuyue commented 8 months ago

Hi, please see docs/README.md#onnx-export.

evdcush commented 7 months ago

NICE! Exactly what I was looking for! 🤓 Thank you @kaiyuyue! 🙇🏻‍♂️