onnx / onnx-mlir

Representation and Reference Lowering of ONNX Models in MLIR Compiler Infrastructure
Apache License 2.0

Use DisposableElementsAttr for ZHigh constant propagation #3013

Open tungld opened 1 week ago

tungld commented 1 week ago

Quick experiment: the peak compile-time memory consumption of #2917 and of this PR when compiling the gpt2-large model for NNPA (744M parameters; the constant file is 3.2 GB) is quite similar: both are about 9 GB.

This patch contains the reverting code, so it's not easy to follow. To ease the review, I merged all new changes (not the reverting code) into a single commit: https://github.com/onnx/onnx-mlir/pull/3013/commits/265ff9029c151f7a4a9473f23015414cd201f7e2. Please look at that commit for review.

AlexandreEichenberger commented 1 week ago

@tungld Just to understand the high level, without the class names: you are using Soren's approach of applying "logical" operations to the constants, so that, for example, if we have <large-constant-tensor> * 2 + 1 we just keep the original <large-constant-tensor> and tag the mult and add operators along with the constant, and if we need to materialize the multiplied/added large constant tensor, we first apply these operations before generating the constant? And so you added a stickify (presumably we never need an unstickify) operator?

tungld commented 1 week ago

> You are using Soren's approach of applying "logical" operations to the constants

Yes, I extended it to ZHigh operations, so the same approach is used for both ONNX and ZHigh until lowering to krnl. We could extend it to cover krnl operations as well, but that needs more work and I didn't do it in this PR.
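For readers unfamiliar with the approach, here is a minimal conceptual sketch of the idea being discussed: keep the original constant buffer and record element-wise transforms lazily, materializing the result only when the final constant is actually needed. This is not the real DisposableElementsAttr API; all names below are hypothetical and illustrative only.

```cpp
// Conceptual sketch only -- not the actual onnx-mlir DisposableElementsAttr API.
// Idea: share the raw constant buffer and compose transforms lazily; the
// transformed tensor is only materialized on demand, keeping peak memory low.
#include <cstdio>
#include <functional>
#include <memory>
#include <vector>

// Hypothetical lazy constant: shares the original data, records transforms.
struct LazyConstant {
  std::shared_ptr<const std::vector<float>> buffer; // original data, shared
  std::function<float(float)> transform;            // composed elementwise op

  explicit LazyConstant(std::shared_ptr<const std::vector<float>> buf)
      : buffer(std::move(buf)), transform([](float x) { return x; }) {}

  // Record a new elementwise op without touching the shared buffer.
  LazyConstant withTransform(std::function<float(float)> f) const {
    LazyConstant result = *this;
    auto prev = transform;
    result.transform = [prev, f](float x) { return f(prev(x)); };
    return result;
  }

  // Materialize only when the final constant is actually needed.
  std::vector<float> materialize() const {
    std::vector<float> out;
    out.reserve(buffer->size());
    for (float x : *buffer)
      out.push_back(transform(x));
    return out;
  }
};

int main() {
  auto data = std::make_shared<const std::vector<float>>(
      std::vector<float>{1.0f, 2.0f, 3.0f});

  // <large-constant-tensor> * 2 + 1, recorded lazily, applied on materialize.
  LazyConstant c(data);
  c = c.withTransform([](float x) { return x * 2.0f; });
  c = c.withTransform([](float x) { return x + 1.0f; });

  for (float v : c.materialize())
    std::printf("%g\n", v); // prints 3, 5, 7
  return 0;
}
```

In this framing, stickification would presumably be just one more deferred transform recorded on the constant (over the whole tensor rather than per element), applied only when the ZHigh constant is finally materialized.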

AlexandreEichenberger commented 4 days ago

Can you post the improvements you got here, just for future reference? It does not need to be super detailed. Thanks

imaihal commented 2 days ago

> Can you post the improvements you got here, just for future reference? It does not need to be super detailed. Thanks

I put the measurement results for gpt2-large and Mistral-7b below. For gpt2-large, the peak memory usage is reduced from 8.9 GB to 7.4 GB, and compilation time improves from 5 min 22 sec to 4 min 30 sec. The left graph is the current main and the right graph is PR #3013.

[Graph: gpt2-large compile memory over time, current main (left) vs. PR #3013 (right)]

For Mistral-7b, the peak memory usage is reduced from 33.2 GB to 27.9 GB, and compilation time improves from 17 min 4 sec to 13 min 58 sec. The left graph is the current main and the right graph is PR #3013.

[Graph: Mistral-7b compile memory over time, current main (left) vs. PR #3013 (right)]