ThanatosShinji / onnx-tool

A parser, editor and profiler tool for ONNX models.
https://pypi.org/project/onnx-tool/
MIT License

Support grouped queries in the attention module without replicating weights #68

Closed: cliffsong94 closed this issue 8 months ago

cliffsong94 commented 8 months ago

Is your feature request related to a problem? Please describe.

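For a MatMul whose inputs are query with shape [1, 4, 8, 64, 64] and key with shape [1, 4, 1, 64, 2048], the inferred output shape is [1, 4, 1, 64, 2048], which is wrong. With the batch dimensions broadcast, it should be [1, 4, 8, 64, 2048].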
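ONNX specifies MatMul as behaving like numpy.matmul, so the leading (batch) dimensions of the two inputs are broadcast against each other rather than taken from one input. The following is a minimal pure-Python sketch of that rule, assuming both inputs have rank >= 2. The helpers broadcast_shapes and matmul_output_shape are hypothetical and are not part of onnx-tool. It computes the output shape expected for the shapes above:

```python
def broadcast_shapes(a, b):
    """Broadcast two shapes with NumPy/ONNX multidirectional rules (hypothetical helper)."""
    # Right-align the shapes by left-padding the shorter one with 1s.
    n = max(len(a), len(b))
    a = [1] * (n - len(a)) + list(a)
    b = [1] * (n - len(b)) + list(b)
    out = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"incompatible dims {da} and {db}")
    return out


def matmul_output_shape(xshape, wshape):
    """MatMul output shape for inputs of rank >= 2 (hypothetical helper)."""
    if xshape[-1] != wshape[-2]:
        raise ValueError("inner dimensions do not match")
    # Batch dims (all but the last two) broadcast; the last two follow the matrix product.
    batch = broadcast_shapes(xshape[:-2], wshape[:-2])
    return batch + [xshape[-2], wshape[-1]]


query = [1, 4, 8, 64, 64]
key = [1, 4, 1, 64, 2048]
print(matmul_output_shape(query, key))  # [1, 4, 8, 64, 2048]
```

Here the key's size-1 group dimension is stretched to match the 8 query heads per group, so no weight replication is needed to get the right shape.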

Describe the solution you'd like

The batch shape is computed at this line: https://github.com/ThanatosShinji/onnx-tool/blob/125836d486a03cb3dbbef05f5412803ff86993db/onnx_tool/node.py#L688

Change that line to:

batchshape = xshape[:-2] if len(xshape) >= len(wshape) else wshape[:-2]

ThanatosShinji commented 8 months ago

How about:

batchshape = xshape[:-2] if volume(xshape[:-2]) >= volume(wshape[:-2]) else wshape[:-2]

cliffsong94 commented 8 months ago

Yeah, it is even better
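The two proposed rules can be compared on a few shapes with a small sketch. It assumes onnx-tool's volume() returns the product of a shape's dimensions, with math.prod standing in for it. The names rank_rule, volume_rule and broadcast_rule are hypothetical. broadcast_rule applies full NumPy-style broadcasting to the batch dimensions and is included only as a reference point:

```python
from math import prod


def rank_rule(xshape, wshape):
    # Batch shape chosen by input rank, as first proposed in the issue.
    return xshape[:-2] if len(xshape) >= len(wshape) else wshape[:-2]


def volume_rule(xshape, wshape):
    # Batch shape chosen by batch element count, as suggested in the reply.
    # Assumption: volume() is the product of the dims; math.prod stands in for it.
    return xshape[:-2] if prod(xshape[:-2]) >= prod(wshape[:-2]) else wshape[:-2]


def broadcast_rule(xshape, wshape):
    # Reference: NumPy-style broadcast of the batch dims of both inputs.
    a, b = list(xshape[:-2]), list(wshape[:-2])
    n = max(len(a), len(b))
    a = [1] * (n - len(a)) + a
    b = [1] * (n - len(b)) + b
    out = []
    for da, db in zip(a, b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"incompatible batch dims {da} and {db}")
        # A size-1 dim takes the other input's size; otherwise both are equal.
        out.append(db if da == 1 else da)
    return out


cases = [
    ([1, 4, 8, 64, 64], [1, 4, 1, 64, 2048]),  # shapes from this issue
    ([8, 64, 64], [1, 1, 64, 2048]),           # lower-rank input carries the batch
    ([8, 1, 64, 64], [1, 4, 64, 2048]),        # each input contributes a batch dim
]
for x, w in cases:
    print(rank_rule(x, w), volume_rule(x, w), broadcast_rule(x, w))
# [1, 4, 8] [1, 4, 8] [1, 4, 8]
# [1, 1] [8] [1, 8]
# [8, 1] [8, 1] [8, 4]
```

Under these assumptions, both rules fix the shapes from this issue. The volume rule also picks the input that carries the batch in the second case, giving the right element count though not the padded rank. Neither heuristic produces [8, 4] in the third case, where each input supplies its own broadcast dimension; only broadcasting does.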

ThanatosShinji commented 8 months ago

Fixed in commit 7d8cca70.