cerndb / dist-keras

Distributed Deep Learning, with a focus on distributed training, using Keras and Apache Spark.
http://joerihermans.com/work/distributed-keras/
GNU General Public License v3.0
623 stars 169 forks source link

Model predicts same value for all the features. #59

Open ankushreddy opened 6 years ago

ankushreddy commented 6 years ago

My model is predicting the same value for all the features.

I am trying to predict the color of images. I created a dataframe with 12,000 images. The schema of the dataframe is:

  df.printSchema()
   root
       |-- Image_Id_new: string (nullable = true)
       |-- Color: string (nullable = true)
       |-- rawfeatures: vector (nullable = true)

  # Number of colour classes; also passed to LabelIndexTransformer in evaluate_accuracy.
  nb_classes = 11
  # Adds a "label_encoded" column derived from "Color".
  # NOTE(review): the schema lists Color as a string column — presumably it holds
  # class indices 0..10 that OneHotTransformer can consume; confirm.
  encoder = OneHotTransformer(nb_classes, input_col="Color", output_col="label_encoded")
  dataset_train = encoder.transform(final_train)
  dataset_test = encoder.transform(final_test)

  # Rename rawfeatures -> features and Color -> label; keep label_encoded as is.
  dataset_train = dataset_train.selectExpr("rawfeatures as features", "Color as label", "label_encoded")
 dataset_test = dataset_test.selectExpr("rawfeatures as features", "Color as label", "label_encoded")

Clear the dataset in the case you ran this cell before.

 # Keep only the features, label and label_encoded columns, dropping anything
 # added to these dataframes by an earlier run of the following cells.
 dataset_train = dataset_train.select("features", "label", "label_encoded")
 dataset_test = dataset_test.select("features", "label", "label_encoded")

Allocate a MinMaxTransformer using Distributed Keras.

o_min -> original_minimum

n_min -> new_minimum

 # Presumably rescales "features" from the original range [o_min, o_max] to the
 # new range [n_min, n_max], writing the result to "features_normalized".
 # NOTE(review): o_max is 250.0; if the features are 8-bit pixel values the
 # original maximum would be 255 — confirm.
 transformer = MinMaxTransformer(
     n_min=0.0, n_max=1.0,
     o_min=0.0, o_max=250.0,
     input_col="features",
     output_col="features_normalized",
 )

Transform the dataset.

 # Apply the MinMaxTransformer to both splits (adds "features_normalized").
 dataset_train = transformer.transform(dataset_train)
 dataset_test = transformer.transform(dataset_test)

  # Presumably reshapes "features_normalized" into a (100, 100, 3) "matrix" column.
  # NOTE(review): the MLP below takes a flat 30000-element input and the trainer
  # reads "features_normalized_dense", so "matrix" looks unused by training — confirm.
  reshape_transformer = ReshapeTransformer("features_normalized", "matrix", (100, 100, 3))
 dataset_train = reshape_transformer.transform(dataset_train)
 dataset_test = reshape_transformer.transform(dataset_test)

 # Fully connected classifier: 30000 inputs -> Dense(11) -> Dense(128) -> 11-way softmax.
 # NOTE(review): the first hidden layer has only 11 units for 30000 inputs;
 # check whether such a narrow layer is intended.
 mlp = Sequential([
     Dense(11, input_shape=(30000,)),
     Activation('relu'),
     Dropout(0.2),
     Dense(128),
     Activation('relu'),
     Dropout(0.5),
     Dense(11),
     Activation('softmax'),
 ])
 mlp.summary()

 # Optimizer and loss handed to the DOWNPOUR trainer (worker_optimizer= / loss=).
 optimizer_mlp = 'adam'
 loss_mlp = 'categorical_crossentropy'

 def evaluate_accuracy(model, test_set, features="features_normalized_dense"):
     """Score `model` on `test_set` with an AccuracyEvaluator.

     Args:
         model: Keras model handed to ModelPredictor.
         test_set: dataframe containing the `features` column and a "label" column.
         features: name of the column the model reads its input from.

     Returns:
         Whatever AccuracyEvaluator.evaluate returns for the "prediction_index"
         and "label" columns (printed as "Accuracy" by the caller).

     Reads `nb_classes` from the enclosing scope.
     """
     evaluator = AccuracyEvaluator(prediction_col="prediction_index", label_col="label")
     predictor = ModelPredictor(keras_model=model, features_col=features)
     # Presumably turns the prediction vector into the "prediction_index" column
     # that the evaluator compares against "label" — confirm.
     transformer = LabelIndexTransformer(output_dim=nb_classes)
     # Only the model input and the ground-truth label are needed for scoring.
     test_set = test_set.select(features, "label")
     test_set = predictor.predict(test_set)
     test_set = transformer.transform(test_set)
     return evaluator.evaluate(test_set)

  # Keep only the columns needed from here on.
  dataset_train = dataset_train.select("features_normalized", "matrix", "label", "label_encoded")
  dataset_test = dataset_test.select("features_normalized", "matrix", "label", "label_encoded")

  # Add the "features_normalized_dense" column that the model is trained and
  # evaluated on.
  dense_transformer = DenseTransformer(input_col="features_normalized",
                                       output_col="features_normalized_dense")
  dataset_train = dense_transformer.transform(dataset_train)
  dataset_test = dense_transformer.transform(dataset_test)
  # repartition() returns a new DataFrame instead of modifying the receiver,
  # so the result has to be bound for the call to have any effect.
  dataset_train = dataset_train.repartition(num_workers)
  dataset_test = dataset_test.repartition(num_workers)

Assign the training and test set.

 # Repartition both splits into num_workers partitions.
 # NOTE(review): num_workers is not defined in the visible snippet.
 training_set = dataset_train.repartition(num_workers)
  test_set = dataset_test.repartition(num_workers)

Cache them.

 # Mark both splits for caching with the MEMORY_AND_DISK_2 storage level.
 training_set.persist(StorageLevel.MEMORY_AND_DISK_2)
 test_set.persist(StorageLevel.MEMORY_AND_DISK_2)

  # Print the number of training rows; count() is a Spark action, so this is
  # presumably also what triggers the training set to be computed and cached.
  print(training_set.count())

# Train the MLP with the DOWNPOUR trainer on the dense, normalized features
# against the one-hot encoded labels.
# NOTE(review): num_workers is hard-coded to 1 here while the data was
# repartitioned into num_workers partitions — confirm this is intended.
trainer = DOWNPOUR(keras_model=mlp, worker_optimizer=optimizer_mlp, loss=loss_mlp,
                   num_workers=1, batch_size=32, communication_window=32, num_epoch=5,
                   features_col="features_normalized_dense", label_col="label_encoded")
trained_model = trainer.train(training_set)

print("Training time: " + str(trainer.get_training_time()))

Training time: 235.8617208

  print("Accuracy: " + str(evaluate_accuracy(trained_model, test_set)))

Accuracy: 0.248927038627

  evaluator = AccuracyEvaluator(prediction_col="prediction_index", label_col="label")
  predictor = ModelPredictor(keras_model=trained_model, features_col="features_normalized_dense")
  transformer = LabelIndexTransformer(output_dim=nb_classes)
  test_set = test_set.select("features_normalized_dense", "label")
  test_set = predictor.predict(test_set)

  test_set.select("label","prediction").show(truncate=False)

+-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ |label|prediction | +-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ |8 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |7 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |4 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |6 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |10 
|[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |1 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |3 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |2 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |10 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |7 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |3 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |0 
|[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |2 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |1 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| |0 |[0.14784927666187286,0.19311276078224182,0.08476026356220245,0.1478438526391983,0.05868959426879883,0.06460657715797424,0.04356149211525917,0.06898588687181473,0.0791180431842804,0.04349109157919884,0.06798115372657776]| +-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

ankushreddy commented 6 years ago

Hi @JoeriHermans, can you please look at this and let me know if I am missing anything?

I have been trying to get this working for the past two weeks. Thank you for your help.