davidfox87 / eks-terraform-tfjob-katib-argo

3 stars 0 forks source link

Tensorboard logging #1

Open davidfox87 opened 1 year ago

davidfox87 commented 1 year ago

# One log directory per run: the timestamp suffix keeps successive runs from
# writing into the same folder, so TensorBoard can show them side by side.
# NOTE(review): assumes `datetime` (the class) and `keras` are imported earlier.
logdir = f"logs/scalars/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

# Small regression network: a 16-unit Dense layer that takes 1 input
# feature, followed by a single-unit Dense output layer.
model = keras.models.Sequential()
model.add(keras.layers.Dense(16, input_dim=1))
model.add(keras.layers.Dense(1))

# Mean-squared-error loss, optimized with plain SGD at learning rate 0.2.
model.compile(
    loss='mse',  # keras.losses.mean_squared_error
    optimizer=keras.optimizers.SGD(learning_rate=0.2),
)

print("Training ... With default parameters, this takes less than 10 seconds.")
# Train for 100 epochs, evaluating on (x_test, y_test) after each epoch and
# streaming metrics to TensorBoard through the callback.
# NOTE(review): x_train, y_train, x_test, y_test and train_size are defined
# outside this snippet; presumably train_size == len(x_train), which would make
# every epoch a single full-batch gradient step — confirm against the caller.
training_history = model.fit(
    x_train,  # input
    y_train,  # output
    batch_size=train_size,
    verbose=0,  # Suppress chatty output; use TensorBoard instead
    epochs=100,
    validation_data=(x_test, y_test),
    callbacks=[tensorboard_callback],
)

# history['loss'] is the *training* loss. Because fit() is given
# validation_data=(x_test, y_test), Keras also records the per-epoch loss on
# the test set under 'val_loss', which is what this message reports.
print("Average test loss: ", np.average(training_history.history['val_loss']))