When working with reduction models, symbolic shape inference does not work properly in some particular cases.
To Reproduce
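from mithril.models import *

model = Model()
# Input with variadic leading dimensions "Var1" and one trailing symbolic dimension "a".
model += Buffer()(IOKey(shape=(("Var1", ...), "a")))
# Reduce over the last axis while keeping the reduced dimension.
model += Mean(axis=TBD, keepdim=True)(axis=-1)
print(model.shapes)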
Describe the Bug
The script above prints: "{'$_Buffer_0_output': ['(V1, ...)', 'u1'], '$_Mean_1_output': ['(V1, ...)'], '$input': ['(V1, ...)', 'u1'], '$axis': None, '$keepdim': None}"
Expected Behavior
Since keepdim=True, the "$_Mean_1_output" shape should be "['(V1, ...)', 1]": the reduced last axis is kept with size 1 instead of being dropped.
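To make the expected rule concrete, here is a minimal pure-Python sketch of the shape a single-axis reduction should produce. The helper `reduced_shape` is hypothetical and is not part of mithril. It treats the variadic entry '(V1, ...)' as a single opaque item, which is only a safe simplification for axis=-1 as in the report.

```python
def reduced_shape(shape, axis, keepdim):
    """Return the output shape of a single-axis reduction.

    `shape` is a list of dimension entries (ints or symbolic names).
    Hypothetical helper for illustration only, not mithril's implementation.
    """
    # Normalize a negative axis. Assumption: the variadic entry counts as one
    # item, so this is only meaningful for the trailing axis used here.
    axis = axis % len(shape)
    if keepdim:
        # The reduced axis stays in place with size 1.
        return shape[:axis] + [1] + shape[axis + 1:]
    # The reduced axis is dropped entirely.
    return shape[:axis] + shape[axis + 1:]


input_shape = ["(V1, ...)", "u1"]
print(reduced_shape(input_shape, axis=-1, keepdim=True))   # ['(V1, ...)', 1]
print(reduced_shape(input_shape, axis=-1, keepdim=False))  # ['(V1, ...)']
```

The second call gives the shape currently reported for "$_Mean_1_output", which suggests the keepdim value may not be taken into account during symbolic inference in this case.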
System Info