keras-team / keras

Deep Learning for humans
http://keras.io/

Multi-output model training broken due to unexpected tree.flatten behavior #20346

Open gustavoeb opened 1 month ago

gustavoeb commented 1 month ago

TL;DR: as of Keras 3.5, training a functional model with multiple outputs is broken when the loss and y_true are passed as dicts. It looks like tree.flatten is re-ordering the entries of the y_true dict.

Keras version: 3.5+
Backend: all

Repro code:

from numpy.random import randn
from keras import layers, Model

if __name__ == "__main__":
    input_a = layers.Input((10,), name="a")
    input_b = layers.Input((5,), name="b")
    output_c = layers.Reshape([5,4], name="c")(layers.Dense(20)(input_a))
    output_d = layers.Dense(1, name="d")(input_b)

    # Model outputs are ordered [d, c], while the loss and targets below
    # are passed as dicts keyed by output name.
    model = Model([input_a, input_b], [output_d, output_c])
    model.compile(loss={"d": "mse", "c": "mae"}, optimizer="adam")

    model.fit(
        {"a": randn(32, 10), "b": randn(32, 5)},
        {"d": randn(32, 1), "c": randn(32, 5, 4)},
    )

To the best of my understanding this is valid code, and it worked up to 3.4. Now it seems this line is re-ordering y_true: https://github.com/keras-team/keras/blob/acceb5a995b82a006bb174b396844afc9f1fd052/keras/src/trainers/compile_utils.py#L536 For the example above, y_true becomes [c, d] while y_pred stays [d, c]. In my few attempts the ordering appeared to be alphabetical by key.
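
A minimal sketch of the suspected behavior, assuming keras.tree.flatten follows the usual sorted-key convention for dicts (as dm-tree/optree do); the string values are just placeholders:

from keras import tree

# Targets as passed to fit(), keyed by output name, written in the order [d, c].
y_true = {"d": "d_targets", "c": "c_targets"}

# With sorted-key flattening this prints ['c_targets', 'd_targets'],
# i.e. [c, d], which no longer matches the model's output order [d, c].
print(tree.flatten(y_true))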

nicolaspi commented 1 month ago

The following code, which aligns the structure of the model's outputs with the loss structure, works on the current master branch:

from numpy.random import randn
from keras import layers, Model

input_a = layers.Input((10,), name="a")
input_b = layers.Input((5,), name="b")
output_c = layers.Reshape([5, 4], name="c")(layers.Dense(20)(input_a))
output_d = layers.Dense(1, name="d")(input_b)

model = Model({"a": input_a, "b": input_b}, {"d": output_d, "c": output_c})
model.compile(loss={"d": "mse", "c": "mae"}, optimizer="adam")

model.fit({"a": randn(32, 10), "b": randn(32, 5)},
  {"d": randn(32, 1), "c": randn(32, 5, 4)}, )
gustavoeb commented 1 month ago

Thanks @nicolaspi, that fixes the issue for me, and I actually like this pattern better. It is still a bug, though, since the workflow with differing data structures is supposed to be supported; at least for the inputs I now get a warning about the structure mismatch. Maybe consider deprecating support for mismatched structures in the future?

nicolaspi commented 1 month ago

Hi @gustavoeb, you should receive both a warning and an exception; please check this gist.