NeuralNetworkVerification / Marabou


C++ ONNX parser performance issues #754

Open MatthewDaggitt opened 7 months ago

MatthewDaggitt commented 7 months ago

Follow-up from #744 (see the example network there).

The C++ parser is currently slower than the Python one because my crude implementation of multi-dimensional tensors is not as efficient as numpy's. In particular, the two offending operations are transpose:

https://github.com/NeuralNetworkVerification/Marabou/blob/571748e3058b792c8c2eaee1c7381ad8bc878e2d/src/input_parsers/TensorUtils.h#L80-L89

and all the lookups and broadcasting done here:

https://github.com/NeuralNetworkVerification/Marabou/blob/571748e3058b792c8c2eaee1c7381ad8bc878e2d/src/input_parsers/OnnxParser.cpp#L1637-L1646

I'm a bit torn on how to address this. I can either keep refining the tensor representation to incorporate strides, which would make these operations much cheaper, or start depending on an existing tensor library. In the latter case I'm not quite sure what I would use...

wu-haoze commented 7 months ago

Hi @MatthewDaggitt, I think we do want to fix this performance issue. If parsing the small network I sent already takes 20+ seconds, that's not a good sign. I would be in favor of depending on an existing tensor library. Do you think it would help to use PyTorch or TensorFlow? While this might be overkill for parsing, it would lay a good foundation for our future agenda of incorporating gradient-based methods.

We also need to look at the license of the library we choose. Maybe something to discuss next week.

MatthewDaggitt commented 7 months ago

I'm not sure whether PyTorch or TensorFlow would help. We need efficient array operations over arbitrary-dimensional C++ arrays, but I suspect the overhead of setting up an entire PyTorch/TensorFlow neural network just to compute them won't be worth it. Maybe I've misunderstood their capabilities though?

I have found that mdspan (multi-dimensional span) was introduced in C++23, which looks promising:

If we implement it ourselves this looks useful:

Otherwise a third party library might be:

wu-haoze commented 7 months ago

I was suggesting using PyTorch/TensorFlow just for tensor computation. According to PyTorch's documentation, it supports general operations over tensors independent of machine learning, though I haven't looked at it closely.

https://pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html

MatthewDaggitt commented 7 months ago

Ah, I didn't realise they supported pure tensor operations as well, although it makes perfect sense when you think about it. That said, the C++ PyTorch documentation is very minimal (https://pytorch.org/cppdocs/api/namespace_at.html#namespace-at), but I guess it's workable.

wu-haoze commented 7 months ago

@MatthewDaggitt, it'd be great if PyTorch is workable. That would lay a great foundation for our future development work. Their license also seems quite permissive: https://github.com/pytorch/pytorch/blob/main/LICENSE

MatthewDaggitt commented 6 months ago

Hmm, so PyTorch is now being included fine, but I'm running into the problem that it also defines a LOG macro, which clashes with our LOG macro. Do we have any preferences in the project on how to deal with this problem?

wu-haoze commented 6 months ago


I'm not sure we have a protocol for this. The solution here seems sensible? https://stackoverflow.com/questions/7109795/macro-definition-conflict

@guykatzz, do you have any suggestions for resolving conflicting macro definitions?

wu-haoze commented 6 months ago

@MatthewDaggitt, actually I just recalled that we have encountered similar issues before, and undefining the clashing macros seems to be the solution:

https://github.com/NeuralNetworkVerification/Marabou/blob/8129640537d63deac485daaf0f2f1c09e247e928/src/engine/MaxConstraint.cpp#L29-L32

MatthewDaggitt commented 6 months ago

Okay, after much suffering I think I've finally got libtorch included in our build process without breaking everything 🎉 It can be found on the branch use-pytorch.

The whole experience has really made me appreciate the sane package managers and hygienic scoping of more modern languages. The next step is to actually use Torch to handle the tensors...

MatthewDaggitt commented 6 months ago

I was overly optimistic in my last comment. Unfortunately, it seems that simply adding #include <torch/torch.h> to the header file degrades the performance of the Marabou tests by a factor of about 100x on my machine.

@wu-haoze, any chance you could confirm it's the same on your machine? Pull the branch use-pytorch and build, then uncomment lines 23 and 25 in OnnxParser.h and build again.

MatthewDaggitt commented 6 months ago

Hmm, you don't even have to include it in the header to trigger the problem: simply linking against it is enough to degrade performance. I've asked a question on Stack Overflow here:

https://stackoverflow.com/questions/78125867/why-might-the-single-act-of-linking-against-a-library-drastically-degrade-the-ru

🤞 Hopefully we get a useful answer.

MatthewDaggitt commented 6 months ago

A personal contact of mine also says:

Statically linked libraries may have some initialization code, and that may cause the issues you were seeing. You could investigate this a little with a debugger, but you could also try to circumvent it, or at least mitigate it, by simply linking to the library at runtime.

wu-haoze commented 6 months ago


I tried it on my end and observed the same phenomenon.

wu-haoze commented 6 months ago


Does linking to the library at runtime require building PyTorch as a dynamically linked (shared) library?

MatthewDaggitt commented 6 months ago

I suspect so, but I'm not sure it would solve the problem. Surely we'd have a similar start-up time whenever we loaded the module instead, which we always would, thanks to the proposed dependency on ONNX...

MatthewDaggitt commented 6 months ago

Okay, I've been playing around with this some more, and I haven't found a way of avoiding the start-up cost. That said, I haven't managed to get PyTorch to build from source. What are your opinions, @wu-haoze? Is adding a fixed 2 seconds to Marabou's start-up time acceptable?

barrettcw commented 6 months ago

I don't think we want to increase the start-up time by 2 seconds.

wu-haoze commented 6 months ago

I agree with Clark. Several use cases of Marabou involve solving short-running queries many, many times, and a 2-second start-up time would be too much. :(

On 20 March 2024, at 18:28, MatthewDaggitt wrote:

Okay I've been playing around with this some more, and I haven't found a way of avoiding the start-up cost. Having said that I've failed to get PyTorch to build from source. What are your opinions @wu-haoze? Is increasing the flat start-up time of Marabou by 2 seconds acceptable?