synnada-ai / mithril

Mithril: A Modular Machine Learning Library for Model Composability
Apache License 2.0

[BUG] Reduction models symbolic inference does not work properly #31

Closed: aturker-synnada closed this issue 6 days ago

aturker-synnada commented 1 week ago

Describe the Bug

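While experimenting with reduction models, I found that symbolic shape inference produces a wrong output shape in some cases. In the example below, Mean is applied with keepdim=True over the last axis (axis=-1) of an input whose shape has a variadic prefix. The reduced axis is dropped from the inferred output shape instead of being kept with size 1.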

To Reproduce

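from mithril.models import *

model = Model()
# Input shape: a variadic prefix ("Var1", ...) followed by one named dim "a".
model += Buffer()(IOKey(shape=(("Var1", ...), "a")))
# axis is left TBD when constructing Mean and bound to -1 (the last axis) at connection time.
model += Mean(axis=TBD, keepdim=True)(axis=-1)
print(model.shapes)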

prints: "{'$_Buffer_0_output': ['(V1, ...)', 'u1'], '$_Mean_1_output': ['(V1, ...)'], '$input': ['(V1, ...)', 'u1'], '$axis': None, '$keepdim': None}"

Expected Behavior

The shape of "$_Mean_1_output" should be "['(V1, ...)', 1]": with keepdim=True, the reduced last axis (u1) should remain in the output with size 1.
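To make the expected rule concrete, here is a minimal Python sketch of how a reduction over a negative axis can be resolved against a shape with a variadic prefix. This is not Mithril's implementation: the function reduce_shape and the "prefix string plus list of suffix dims" representation are hypothetical, chosen only to mirror the printed shapes above.

```python
from typing import List, Union

# A symbolic dim is either an int (known size) or a str (named/unknown size
# or a variadic placeholder such as "(V1, ...)").
Dim = Union[int, str]


def reduce_shape(prefix: str, suffix: List[Dim], axis: int, keepdim: bool) -> List[Dim]:
    """Infer the output shape of a reduction over a single axis.

    The input shape is modelled as a variadic prefix (e.g. "(V1, ...)")
    followed by a fixed-length suffix of dims (e.g. ["u1"]). Only negative
    axes that land inside the known suffix can be resolved without knowing
    how many dims the prefix holds.
    """
    if not -len(suffix) <= axis < 0:
        raise ValueError("axis cannot be resolved against the known suffix")
    idx = len(suffix) + axis  # position of the reduced dim inside the suffix
    if keepdim:
        # keepdim=True: the reduced dim stays in place with size 1.
        new_suffix = suffix[:idx] + [1] + suffix[idx + 1:]
    else:
        # keepdim=False: the reduced dim is removed.
        new_suffix = suffix[:idx] + suffix[idx + 1:]
    return [prefix] + new_suffix


print(reduce_shape("(V1, ...)", ["u1"], axis=-1, keepdim=True))   # ['(V1, ...)', 1]
print(reduce_shape("(V1, ...)", ["u1"], axis=-1, keepdim=False))  # ['(V1, ...)']
```

Note that the keepdim=False result of this sketch is the same shape the report observed for "$_Mean_1_output", while the keepdim=True result is the expected one.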

System Info