Open achoum opened 3 years ago
Hi @achoum, I found this; I think it describes what you are asking for. Hope it is useful to you. Here is a verbose example of 10-fold cross-validation of TF-DF using sklearn.
import numpy as np
from sklearn.model_selection import KFold
import tensorflow_decision_forests as tfdf

accuracies_per_fold = []  # Test accuracy on the individual folds.

# Run a 10-fold cross-validation. "all_df" is the full pandas DataFrame.
for fold_idx, (train_indices, test_indices) in enumerate(KFold(n_splits=10, shuffle=True).split(all_df)):
    print(f"Running fold {fold_idx+1}")

    # Extract the training and testing examples.
    sub_train_df = all_df.iloc[train_indices]
    sub_test_df = all_df.iloc[test_indices]

    # Convert the examples into TensorFlow datasets.
    sub_train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(sub_train_df, label="income")
    sub_test_ds = tfdf.keras.pd_dataframe_to_tf_dataset(sub_test_df, label="income")

    # Train the model.
    model = tfdf.keras.GradientBoostedTreesModel()
    model.fit(sub_train_ds, verbose=False)

    # Evaluate the model.
    model.compile(metrics=["accuracy"])
    evaluation = model.evaluate(sub_test_ds, return_dict=True, verbose=False)
    print(f"Evaluation {evaluation}")
    accuracies_per_fold.append(evaluation["accuracy"])

print(f"Cross-validated accuracy: {np.mean(accuracies_per_fold)}")
Output:
Running fold 1
Evaluation {'loss': 0.0, 'accuracy': 0.8780701756477356}
Running fold 2
Evaluation {'loss': 0.0, 'accuracy': 0.8833333253860474}
Running fold 3
Evaluation {'loss': 0.0, 'accuracy': 0.8841597437858582}
Running fold 4
Evaluation {'loss': 0.0, 'accuracy': 0.8692408800125122}
Running fold 5
Evaluation {'loss': 0.0, 'accuracy': 0.8679245114326477}
Running fold 6
Evaluation {'loss': 0.0, 'accuracy': 0.8639754056930542}
Running fold 7
Evaluation {'loss': 0.0, 'accuracy': 0.8745063543319702}
Running fold 8
Evaluation {'loss': 0.0, 'accuracy': 0.8679245114326477}
Running fold 9
Evaluation {'loss': 0.0, 'accuracy': 0.8609039187431335}
Running fold 10
Evaluation {'loss': 0.0, 'accuracy': 0.8613426685333252}
Cross-validated accuracy: 0.8711381494998932
Credit: @Mathieu
Decision Forests work well on small datasets, where cross-validation is commonly used. It would be valuable to easily run cross-validations and report cross-validation metrics (evaluation metrics, confidence intervals, statistical tests, etc.).
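As a starting point for the confidence-interval part, the per-fold accuracies collected above can be summarized with a simple interval around the mean. This is a sketch, not part of TF-DF: the helper name, the z = 1.96 normal approximation, and the rounded accuracies are my own choices for illustration; with only 10 folds, a t-based interval would be more accurate.

```python
import math
import statistics

def cross_val_summary(fold_accuracies, z=1.96):
    """Mean fold accuracy with an approximate 95% normal-based interval."""
    k = len(fold_accuracies)
    mean = statistics.fmean(fold_accuracies)
    # Sample standard deviation across folds, then standard error of the mean.
    sem = statistics.stdev(fold_accuracies) / math.sqrt(k)
    return mean, (mean - z * sem, mean + z * sem)

# Rounded fold accuracies from the run above.
accs = [0.8781, 0.8833, 0.8842, 0.8692, 0.8679,
        0.8640, 0.8745, 0.8679, 0.8609, 0.8613]
mean, (low, high) = cross_val_summary(accs)
print(f"Accuracy: {mean:.4f} (95% CI: [{low:.4f}, {high:.4f}])")
```

The interval quantifies how much the estimate moves from fold to fold, which is exactly the kind of report the issue is asking the library to produce automatically.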