tensorflow / neural-structured-learning

Training neural models with structured signals.
https://www.tensorflow.org/neural_structured_learning
Apache License 2.0

Introducing Val_data in Fit function #44

Closed pascoa-pand closed 3 years ago

pascoa-pand commented 4 years ago

Hello guys,

I was trying NSL out, but I was unable to run .fit with my validation data. What is the correct way to do it?

The following seems to be working just fine:

adv_model.fit({'feature': Xtrain_array, 'label': Ytrain_array}, epochs=20)

I tried converting the training and test arrays into tf.data.Dataset objects with the following code, but I get the error "KeyError: 'feature'":

x_train = Xtrain_array
y_train = Ytrain_array
x_val = Xtest_array
y_val = Ytest_array
batch_size = 32
train_data = tf.data.Dataset.from_tensor_slices({'input': x_train, 'label': y_train}).batch(batch_size)
val_data = tf.data.Dataset.from_tensor_slices({'input': x_val, 'label': y_val}).batch(batch_size)
val_steps = x_val.shape[0] / batch_size
adv_model.fit(train_data, validation_data=val_data, validation_steps=val_steps, epochs=2, verbose=1)
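A side note on the snippet above: `x_val.shape[0] / batch_size` is true division, so in Python 3 `val_steps` is a float (for example 312.5 for 10,000 validation examples), whereas Keras documents `validation_steps` as an integer number of batches. For a finite tf.data.Dataset it can also simply be left out, in which case validation runs until the dataset is exhausted. A minimal standard-library sketch of an integer batch count, assuming the final partial batch is kept (the default `drop_remainder=False` of `Dataset.batch`); `num_batches` is a hypothetical helper name, not part of TF or NSL:

```python
def num_batches(num_examples: int, batch_size: int) -> int:
    """Return how many batches Dataset.batch(batch_size) yields when
    drop_remainder=False, i.e. ceil(num_examples / batch_size)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Integer ceiling division: stays an int, unlike `num_examples / batch_size`.
    return (num_examples + batch_size - 1) // batch_size


# Hypothetical validation-set size, for illustration only.
val_steps = num_batches(10000, 32)
print(val_steps)  # 313
```

The result could then be passed as `validation_steps=val_steps`.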

csferng commented 4 years ago

@pascoa-pand, thanks for the question.

Which versions of TensorFlow and NSL are you using?

I am able to run .fit with validation data (passed as a tf.data.Dataset object) using TF 2.1.0. The following code (based on the example here) works on Google Colab:

adv_model.fit(train_set_for_adv_model,
              validation_data=test_set_for_adv_model,
              validation_steps=10,
              epochs=2)

And the output includes validation losses:

Epoch 1/2
1875/1875 [==============================] - 22s 12ms/step - loss: 0.0700 - sparse_categorical_crossentropy: 0.0240 - sparse_categorical_accuracy: 0.9923 - adversarial_loss: 0.2296 - val_loss: 0.0820 - val_sparse_categorical_crossentropy: 0.0344 - val_sparse_categorical_accuracy: 0.9844 - val_adversarial_loss: 0.2382
Epoch 2/2
1875/1875 [==============================] - 17s 9ms/step - loss: 0.0574 - sparse_categorical_crossentropy: 0.0196 - sparse_categorical_accuracy: 0.9940 - adversarial_loss: 0.1892 - val_loss: 0.0837 - val_sparse_categorical_crossentropy: 0.0462 - val_sparse_categorical_accuracy: 0.9875 - val_adversarial_loss: 0.1875

csferng commented 3 years ago

Closing this issue for now. Feel free to reopen if the error occurs again.