ankushreddy opened this issue 6 years ago
My model predicts the same value for every input.

I am trying to predict the color of images. I built a DataFrame of 12,000 images; its schema is:
```
df.printSchema()

root
 |-- Image_Id_new: string (nullable = true)
 |-- Color: string (nullable = true)
 |-- rawfeatures: vector (nullable = true)
```

```python
nb_classes = 11

encoder = OneHotTransformer(nb_classes, input_col="Color", output_col="label_encoded")
dataset_train = encoder.transform(final_train)
dataset_test = encoder.transform(final_test)

dataset_train = dataset_train.selectExpr("rawfeatures as features", "Color as label", "label_encoded")
dataset_test = dataset_test.selectExpr("rawfeatures as features", "Color as label", "label_encoded")

# Clear the dataset in case you ran this cell before.
dataset_train = dataset_train.select("features", "label", "label_encoded")
dataset_test = dataset_test.select("features", "label", "label_encoded")
```
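For reference, here is a quick sanity check on the encoding and the class balance (an illustrative sketch, not from my original run): `label_encoded` should hold a distinct 11-dimensional one-hot vector per class, and the class frequencies matter because a model that collapses to a single class scores roughly that class's frequency.

```python
# Each label should map to a distinct 11-dimensional one-hot vector.
dataset_train.select("label", "label_encoded").show(5, truncate=False)

# Class balance: a collapsed model scores about the frequency
# of whichever class it always predicts.
dataset_train.groupBy("label").count().orderBy("count", ascending=False).show()
```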
```python
# Allocate a MinMaxTransformer using Distributed Keras.
# o_min -> original_minimum
# n_min -> new_minimum
transformer = MinMaxTransformer(n_min=0.0, n_max=1.0,
                                o_min=0.0, o_max=250.0,
                                input_col="features",
                                output_col="features_normalized")

# Transform the dataset.
dataset_train = transformer.transform(dataset_train)
dataset_test = transformer.transform(dataset_test)

reshape_transformer = ReshapeTransformer("features_normalized", "matrix", (100, 100, 3))
dataset_train = reshape_transformer.transform(dataset_train)
dataset_test = reshape_transformer.transform(dataset_test)

mlp = Sequential()
mlp.add(Dense(11, input_shape=(30000,)))
mlp.add(Activation('relu'))
mlp.add(Dropout(0.2))
mlp.add(Dense(128))
mlp.add(Activation('relu'))
mlp.add(Dropout(0.5))
mlp.add(Dense(11))
mlp.add(Activation('softmax'))

mlp.summary()

optimizer_mlp = 'adam'
loss_mlp = 'categorical_crossentropy'

def evaluate_accuracy(model, test_set, features="features_normalized_dense"):
    evaluator = AccuracyEvaluator(prediction_col="prediction_index", label_col="label")
    predictor = ModelPredictor(keras_model=model, features_col=features)
    transformer = LabelIndexTransformer(output_dim=nb_classes)
    test_set = test_set.select(features, "label")
    test_set = predictor.predict(test_set)
    test_set = transformer.transform(test_set)
    score = evaluator.evaluate(test_set)
    return score

dataset_train = dataset_train.select("features_normalized", "matrix", "label", "label_encoded")
dataset_test = dataset_test.select("features_normalized", "matrix", "label", "label_encoded")

dense_transformer = DenseTransformer(input_col="features_normalized", output_col="features_normalized_dense")
dataset_train = dense_transformer.transform(dataset_train)
dataset_test = dense_transformer.transform(dataset_test)

dataset_train.repartition(num_workers)
dataset_test.repartition(num_workers)

# Assign the training and test set.
training_set = dataset_train.repartition(num_workers)
test_set = dataset_test.repartition(num_workers)

# Cache them.
training_set.persist(StorageLevel.MEMORY_AND_DISK_2)
test_set.persist(StorageLevel.MEMORY_AND_DISK_2)

print(training_set.count())
```
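One control I can still run to rule out the distributed trainer: fit the same architecture with plain Keras on a small sample collected to the driver. If that also produces a constant prediction, the problem is in the data or the model rather than in dist-keras. A sketch (the 0.05 sample fraction is just illustrative, and I am assuming both columns hold Spark DenseVectors):

```python
import numpy as np

# Collect a small sample to the driver for a single-machine control run.
sample = training_set.select("features_normalized_dense", "label_encoded") \
                     .sample(False, 0.05).collect()
# .toArray() assumes DenseVector columns; drop it if these are plain arrays.
X = np.array([row["features_normalized_dense"].toArray() for row in sample])
y = np.array([row["label_encoded"].toArray() for row in sample])

# Same architecture as above, trained with plain Keras.
local_mlp = Sequential()
local_mlp.add(Dense(11, input_shape=(30000,)))
local_mlp.add(Activation('relu'))
local_mlp.add(Dropout(0.2))
local_mlp.add(Dense(128))
local_mlp.add(Activation('relu'))
local_mlp.add(Dropout(0.5))
local_mlp.add(Dense(11))
local_mlp.add(Activation('softmax'))
local_mlp.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Keras 2 uses epochs=; older Keras 1 releases use nb_epoch=.
local_mlp.fit(X, y, batch_size=32, epochs=5, validation_split=0.1)
```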
```python
trainer = DOWNPOUR(keras_model=mlp, worker_optimizer=optimizer_mlp, loss=loss_mlp,
                   num_workers=1, batch_size=32, communication_window=32,
                   num_epoch=5, features_col="features_normalized_dense",
                   label_col="label_encoded")
trained_model = trainer.train(training_set)

print("Training time: " + str(trainer.get_training_time()))
# Training time: 235.8617208

print("Accuracy: " + str(evaluate_accuracy(trained_model, test_set)))
# Accuracy: 0.248927038627

evaluator = AccuracyEvaluator(prediction_col="prediction_index", label_col="label")
predictor = ModelPredictor(keras_model=trained_model, features_col="features_normalized_dense")
transformer = LabelIndexTransformer(output_dim=nb_classes)
test_set = test_set.select("features_normalized_dense", "label")
test_set = predictor.predict(test_set)
```
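One check worth running at this point: since the trainer should work on a serialized copy of the model, the `mlp` object on the driver should still hold its initial weights, so comparing it against `trained_model` shows whether training moved the weights at all (an illustrative sketch, assuming `trained_model` is a plain Keras model):

```python
import numpy as np

# Compare trained weights against the initial weights still held by `mlp`
# on the driver; near-zero mean deltas would point at a training problem.
for trained_layer, initial_layer in zip(trained_model.layers, mlp.layers):
    for w_trained, w_initial in zip(trained_layer.get_weights(),
                                    initial_layer.get_weights()):
        print(trained_layer.name, np.abs(w_trained - w_initial).mean())
```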
```python
test_set.select("label", "prediction").show(truncate=False)
```

```
+-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|label|prediction                                                                                                                                                                                                                 |
+-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|8    |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|0    |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|7    |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
|4    |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]|
+-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
```

(remaining rows omitted; every row shows the identical prediction vector, regardless of the label)
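Since every prediction vector is identical, the argmax is the same class for every row (index 1 here, at about 0.193), which would also explain why the accuracy lands near a single class's frequency. To confirm the collapse over the whole test set rather than just the displayed rows, the predicted indices can be counted (a sketch using the same LabelIndexTransformer as above):

```python
# Count the distinct predicted class indices; a collapsed model
# produces a single index for the entire test set.
indexed = LabelIndexTransformer(output_dim=nb_classes).transform(test_set)
indexed.groupBy("prediction_index").count().show()
```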
Hi @JoeriHermans, could you please take a look at this and let me know if I am missing anything? I have been trying to get this working for the past two weeks. Thank you for your help.