fryguy1013 closed this 1 year ago.
No one seems to have refactored this file, but if you're having trouble with `model.predict`, would you mind providing a short code sample so we can reproduce it?
```csharp
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

// Two-headed model: a shared Conv2D trunk feeding a value head and a policy head.
var inputs = tf.keras.Input(shape: (6, 8, 3), name: "main_input");
var conv2d = tf.keras.layers.Conv2D(32, kernel_size: (3, 3),
    activation: tf.keras.activations.Linear).Apply(inputs);
var valueHead = tf.keras.layers.Dense(units: 1, use_bias: false, activation: tf.keras.activations.Linear).Apply(conv2d);
var policyHead = tf.keras.layers.Dense(units: 8 * 6 * 3, use_bias: false, activation: tf.keras.activations.Linear).Apply(conv2d);
var model = tf.keras.Model(
    inputs: inputs,
    outputs: new Tensors(valueHead, policyHead),
    "predictions"
);

// A single all-zeros sample with a batch dimension of 1.
var inputNN = np.zeros(6 * 8 * 3, TF_DataType.TF_FLOAT);
inputNN = np.reshape(inputNN, (1, 6, 8, 3));

// Expected: one result per output head (value and policy).
var preds = model.predict(new Tensors(inputNN));
```
This was my test. I would expect `preds` to be an enumerable with two elements, one per output head.
As for the refactoring, I just mean that commit 271dcefc added a second method with the current behavior (returning only `tmp_batch_outputs[0]`), and commit 0ee50d3 then refactored the existing code to use that newly created method.
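To make the difference concrete, here is a small self-contained C# sketch. It does not use TensorFlow.NET at all: plain `float[]` arrays stand in for tensors, and the names `PredictSketch`, `FirstOutputOnly`, and `AllOutputs` are hypothetical. It contrasts keeping only output 0 of every batch (similar in spirit to returning `tmp_batch_outputs[0]`) with keeping every output, which is what a two-headed model like the one in the test would need.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical illustration: float[] stands in for a tensor.
// Each element of batchOutputs holds one array per model output
// (e.g. index 0 = value head, index 1 = policy head).
public static class PredictSketch
{
    // Keeps only output 0 of every batch, concatenated across batches.
    public static List<float[]> FirstOutputOnly(List<float[][]> batchOutputs)
    {
        return new List<float[]> { batchOutputs.SelectMany(batch => batch[0]).ToArray() };
    }

    // Keeps every output: result i is output i concatenated across batches.
    public static List<float[]> AllOutputs(List<float[][]> batchOutputs)
    {
        int outputCount = batchOutputs[0].Length;
        var results = new List<float[]>();
        for (int i = 0; i < outputCount; i++)
        {
            int index = i; // per-iteration copy captured by the lambda
            results.Add(batchOutputs.SelectMany(batch => batch[index]).ToArray());
        }
        return results;
    }

    public static void Main()
    {
        // Two batches from a two-output model: a 1-value head and a 3-value head.
        var batches = new List<float[][]>
        {
            new[] { new float[] { 0.5f }, new float[] { 0.1f, 0.2f, 0.7f } },
            new[] { new float[] { -0.5f }, new float[] { 0.3f, 0.3f, 0.4f } },
        };

        Console.WriteLine(FirstOutputOnly(batches).Count); // 1: the policy head is lost
        Console.WriteLine(AllOutputs(batches).Count);      // 2: value and policy heads
        Console.WriteLine(AllOutputs(batches)[1].Length);  // 6: 3 policy values x 2 batches
    }
}
```

Concatenating per output rather than per batch keeps each head's results together, which matches the expectation that `preds` has two elements.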
Looks good to me. @Oceania2018
I think someone with permissions needs to re-run the semantic check job after changing the name of the PR.
All the unit tests and examples have passed. Thanks a lot for your contribution. :) After the merge, the automatic nightly release will be available by adding https://www.myget.org/F/scisharp/api/v3/index.json to your NuGet sources (it may take several minutes for the action to complete).
It seems like this is how the code worked before a refactoring, but I'm not sure.