tensorflow / decision-forests

A collection of state-of-the-art algorithms for the training, serving and interpretation of Decision Forest models in Keras.
Apache License 2.0

Makes it easy to run cross-validations on small datasets #17

Open achoum opened 3 years ago

achoum commented 3 years ago

Decision Forests work well on small datasets, where cross-validation is commonly used. It would be valuable to make it easy to run cross-validation and to report cross-validated metrics (evaluation metrics, confidence intervals, statistical tests, etc.).
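As an illustration of the kind of statistical test meant here, a minimal sketch assuming two models were evaluated on the same folds: a paired t-statistic over the per-fold score differences. `paired_t_statistic` and `is_significant` are hypothetical helpers, not TF-DF APIs, and the default critical value assumes 10 folds. Because the folds share training data, this plain paired test tends to be optimistic; corrected variants such as Nadeau and Bengio's corrected resampled t-test exist for that reason.

```python
import math
import statistics


def paired_t_statistic(scores_a, scores_b):
    """Paired t-statistic of per-fold scores from two models evaluated on the same folds."""
    if len(scores_a) != len(scores_b) or len(scores_a) < 2:
        raise ValueError("expected two equally long score lists with at least two folds")
    diffs = [a - b for a, b in zip(scores_a, scores_b)]
    std_diff = statistics.stdev(diffs)  # sample standard deviation (n - 1 denominator)
    if std_diff == 0:
        raise ValueError("per-fold differences have zero variance")
    return statistics.fmean(diffs) / (std_diff / math.sqrt(len(diffs)))


def is_significant(scores_a, scores_b, t_crit=2.262):
    """Two-sided test at the 5% level; the default t_crit assumes 10 folds (9 degrees of freedom)."""
    return abs(paired_t_statistic(scores_a, scores_b)) > t_crit
```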

Kareem-negm commented 3 years ago

Hi @achoum, I found this and I think it covers what you describe. I hope it is useful. Here is a verbose example of 10-fold cross-validation of TF-DF using sklearn.

from sklearn.model_selection import KFold
import numpy as np
import tensorflow_decision_forests as tfdf

# Assumption: `all_df` is an already-loaded pandas DataFrame with the label in
# the "income" column (e.g. the Adult census dataset from the TF-DF tutorials),
# encoded as integer class indices.

accuracies_per_fold = []  # Test accuracy on the individual folds.

# Run a 10-fold cross-validation.
for fold_idx, (train_indices, test_indices) in enumerate(KFold(n_splits=10, shuffle=True).split(all_df)):

  print(f"Running fold {fold_idx+1}")

  # Extract the training and testing examples.
  sub_train_df = all_df.iloc[train_indices]
  sub_test_df = all_df.iloc[test_indices]

  # Convert the examples into TensorFlow datasets.
  sub_train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(sub_train_df, label="income")
  sub_test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(sub_test_df, label="income")

  # Train a fresh model on this fold's training examples.
  model = tfdf.keras.GradientBoostedTreesModel()
  model.fit(sub_train_ds, verbose=False)

  # Evaluate the model on the held-out fold.
  model.compile(metrics=["accuracy"])
  evaluation = model.evaluate(sub_test_ds, return_dict=True, verbose=False)
  print(f"Evaluation {evaluation}")

  accuracies_per_fold.append(evaluation["accuracy"])

print(f"Cross-validated accuracy: {np.mean(accuracies_per_fold)}")

Output:

Evaluation {'loss': 0.0, 'accuracy': 0.8780701756477356}
Running fold 2
Evaluation {'loss': 0.0, 'accuracy': 0.8833333253860474}
Running fold 3
Evaluation {'loss': 0.0, 'accuracy': 0.8841597437858582}
Running fold 4
Evaluation {'loss': 0.0, 'accuracy': 0.8692408800125122}
Running fold 5
Evaluation {'loss': 0.0, 'accuracy': 0.8679245114326477}
Running fold 6
Evaluation {'loss': 0.0, 'accuracy': 0.8639754056930542}
Running fold 7
Evaluation {'loss': 0.0, 'accuracy': 0.8745063543319702}
Running fold 8
Evaluation {'loss': 0.0, 'accuracy': 0.8679245114326477}
Running fold 9
Evaluation {'loss': 0.0, 'accuracy': 0.8609039187431335}
Running fold 10
Evaluation {'loss': 0.0, 'accuracy': 0.8613426685333252}
Cross-validated accuracy: 0.8711381494998932  
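The mean alone hides the spread across folds. A minimal sketch of reporting a 95% confidence interval from the ten per-fold accuracies printed above, assuming a Student's t interval with 9 degrees of freedom (2.262 is the two-sided 97.5% quantile). The folds are not independent (they share training data), so the interval should be read as indicative rather than exact.

```python
import math
import statistics

# Per-fold test accuracies copied from the output above.
accuracies_per_fold = [
    0.8780701756477356, 0.8833333253860474, 0.8841597437858582,
    0.8692408800125122, 0.8679245114326477, 0.8639754056930542,
    0.8745063543319702, 0.8679245114326477, 0.8609039187431335,
    0.8613426685333252,
]

n = len(accuracies_per_fold)
mean = statistics.fmean(accuracies_per_fold)
std = statistics.stdev(accuracies_per_fold)  # sample std (n - 1 denominator)
t_crit = 2.262  # two-sided 95% Student's t quantile for n - 1 = 9 degrees of freedom
half_width = t_crit * std / math.sqrt(n)

print(f"Cross-validated accuracy: {mean:.4f} "
      f"(95% CI {mean - half_width:.4f} to {mean + half_width:.4f})")
```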

Credit: @Mathieu
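The `KFold(n_splits=10, shuffle=True)` call in the example shuffles the row indices and cuts them into 10 contiguous folds whose sizes differ by at most one. Below is a stdlib-only sketch of that splitting rule for illustration; `kfold_indices` is a hypothetical helper, not the scikit-learn implementation. Since the example does not set `random_state`, its folds (and therefore the per-fold accuracies) change between runs; `KFold(n_splits=10, shuffle=True, random_state=0)` makes them reproducible, and `StratifiedKFold` (whose `split` also takes the labels) keeps each fold's class balance close to that of the full dataset.

```python
import random


def kfold_indices(n_samples, n_splits=10, shuffle=True, seed=None):
    """Yield sorted (train_indices, test_indices) lists, one pair per fold.

    The first n_samples % n_splits folds get one extra example, so fold
    sizes differ by at most one.
    """
    if not 2 <= n_splits <= n_samples:
        raise ValueError("n_splits must be between 2 and n_samples")
    indices = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # seeded for reproducible folds
    fold_sizes = [n_samples // n_splits + (1 if i < n_samples % n_splits else 0)
                  for i in range(n_splits)]
    start = 0
    for size in fold_sizes:
        test = sorted(indices[start:start + size])
        train = sorted(indices[:start] + indices[start + size:])
        yield train, test
        start += size
```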