SciSharp / TensorFlow.NET

.NET Standard bindings for Google's TensorFlow for developing, training and deploying Machine Learning models in C# and F#.
https://scisharp.github.io/tensorflow-net-docs
Apache License 2.0
3.2k stars, 514 forks

fix: predict with multiple outputs #1065

Closed: fryguy1013 closed this pull request 1 year ago

fryguy1013 commented 1 year ago

It seems like this is how the code worked before a refactoring (returning every model output rather than only the first one), but I'm not sure.

Wanglongzhi2001 commented 1 year ago

No one seems to have refactored this file, but if you're having trouble with model.predict, would you mind providing a short code snippet so we can reproduce it?

fryguy1013 commented 1 year ago
```csharp
using static Tensorflow.Binding;   // provides tf
using Tensorflow;                  // Tensors, TF_DataType
using Tensorflow.NumPy;            // np

// Input: a single 6x8x3 tensor.
var inputs = tf.keras.Input(shape: (6, 8, 3), name: "main_input");

var conv2d = tf.keras.layers.Conv2D(32, kernel_size: (3, 3),
    activation: tf.keras.activations.Linear).Apply(inputs);

// Two output heads that share the same convolutional trunk.
var valueHead = tf.keras.layers.Dense(units: 1, use_bias: false, activation: tf.keras.activations.Linear).Apply(conv2d);
var policyHead = tf.keras.layers.Dense(units: 8 * 6 * 3, use_bias: false, activation: tf.keras.activations.Linear).Apply(conv2d);

var model = tf.keras.Model(
    inputs: inputs,
    outputs: new Tensors(valueHead, policyHead),
    "predictions"   // model name
);

// One all-zero sample with a batch dimension of 1.
var inputNN = np.zeros(6 * 8 * 3, TF_DataType.TF_FLOAT);
inputNN = np.reshape(inputNN, (1, 6, 8, 3));
var preds = model.predict(new Tensors(inputNN));
```

This was my test. I would expect preds to be an enumerable with two elements in it, one for each output head (valueHead and policyHead).

As for the refactoring, I just mean that commit 271dcefc added a second method with the current behavior (returning only tmp_batch_outputs[0]), and then 0ee50d3 refactored the existing code to use that newly created method.
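To make the difference concrete, here is a minimal C# sketch. It is not TensorFlow.NET's actual source: `tmpBatchOutputs`, `FirstOutputOnly` and `AllOutputs` are hypothetical names, and plain `float[]` arrays stand in for the tensors that a two-headed model like the one above would produce per batch.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for one batch's outputs of a two-headed model:
// one array per output head. These are NOT TensorFlow.NET types.
var tmpBatchOutputs = new List<float[]>
{
    new float[1],          // value head
    new float[8 * 6 * 3]   // policy head (144 values)
};

// Behavior described above: only the first output is returned,
// so the policy head is silently dropped.
float[][] FirstOutputOnly(List<float[]> outputs) => new[] { outputs[0] };

// Behavior the fix aims for: every output head is returned.
float[][] AllOutputs(List<float[]> outputs) => outputs.ToArray();

Console.WriteLine(FirstOutputOnly(tmpBatchOutputs).Length); // 1
Console.WriteLine(AllOutputs(tmpBatchOutputs).Length);      // 2
```

With the first-output-only behavior, a caller who built the model with two heads gets back a single result and has no way to reach the second one.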

Wanglongzhi2001 commented 1 year ago

Looks good to me. @Oceania2018

fryguy1013 commented 1 year ago

I think someone with permissions needs to re-run the semantic check job after changing the name of the PR.

AsakusaRinne commented 1 year ago

All the unit tests and examples have passed. Thanks a lot for your contribution. :) After the merge, the automatic nightly release is available by adding https://www.myget.org/F/scisharp/api/v3/index.json to your NuGet sources (it may take several minutes for the action to complete).