Closed cliffsong94 closed 8 months ago
how about
batchshape = xshape[:-2] if volume(xshape[:-2]) >= volume(wshape[:-2]) else wshape[:-2]
Yeah, it is even better
Fixed in commit 7d8cca70.
Is your feature request related to a problem? Please describe.
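For MatMul with inputs query [1, 4, 8, 64, 64] and key [1, 4, 1, 64, 2048], the inferred output shape is [1, 4, 1, 64, 2048], which is wrong. Broadcasting the batch dimensions should give [1, 4, 8, 64, 2048].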
Describe the solution you'd like
https://github.com/ThanatosShinji/onnx-tool/blob/125836d486a03cb3dbbef05f5412803ff86993db/onnx_tool/node.py#L688
change this line to batchshape = xshape[:-2] if len(xshape) >= len(wshape) else wshape[:-2]
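For anyone comparing the two one-line suggestions in this thread, here is a minimal sketch in plain Python. It is not the code from onnx_tool/node.py: `volume` is a hypothetical stand-in for the helper named in the comment (assumed to return the element count of a shape), and `batch_by_rank`, `batch_by_volume`, and `batch_by_broadcast` are illustrative names. The third function applies NumPy-style broadcasting to the batch dimensions, which is an alternative to both suggestions rather than what the fix commit necessarily does.

```python
from math import prod


def volume(shape):
    # Hypothetical stand-in: number of elements in a shape (1 for an empty shape).
    return prod(shape)


def batch_by_rank(xshape, wshape):
    # Suggestion from the issue body: pick the batch dims of the higher-rank input.
    return xshape[:-2] if len(xshape) >= len(wshape) else wshape[:-2]


def batch_by_volume(xshape, wshape):
    # Suggestion from the comment: pick the batch dims with more elements.
    return xshape[:-2] if volume(xshape[:-2]) >= volume(wshape[:-2]) else wshape[:-2]


def batch_by_broadcast(xshape, wshape):
    # Alternative (assumption): broadcast the batch dims element-wise.
    xb, wb = list(xshape[:-2]), list(wshape[:-2])
    rank = max(len(xb), len(wb))
    xb = [1] * (rank - len(xb)) + xb  # left-pad the shorter batch shape with 1s
    wb = [1] * (rank - len(wb)) + wb
    out = []
    for a, b in zip(xb, wb):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"batch dims {a} and {b} cannot be broadcast")
        out.append(b if a == 1 else a)
    return out


query = [1, 4, 8, 64, 64]
key = [1, 4, 1, 64, 2048]

for pick in (batch_by_rank, batch_by_volume, batch_by_broadcast):
    # Output shape = batch dims + [rows of x, cols of w]
    print(pick.__name__, pick(query, key) + [query[-2], key[-1]])
```

All three give [1, 4, 8, 64, 2048] for the shapes in this issue. They differ when the broadcast dimension sits in the second operand at equal rank (the rank check then picks the smaller batch), or when each operand has its own size-1 batch dimension, such as batch dims [8, 1] against [1, 4], where only the element-wise broadcast yields [8, 4].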