choshin84 / learning_memo

personal learning memo
0 stars 0 forks source link

Cross validation: Stratified CV #24

Open choshin84 opened 4 years ago

choshin84 commented 4 years ago

Tweet summary

Use stratified cross-validation so that every fold preserves the class distribution of the target variable.

from sklearn.model_selection import cross_validate,  StratifiedKFold
from sklearn.metrics import accuracy_score, recall_score, precision_score
import lightgbm as lgb
# parameter setup
t0 = time.time()  # wall-clock start; elapsed seconds are printed after the CV loop
kfold = 5  # number of cross-validation folds
# Features: every column except `target_column` and 'Target'; label: 'Target'.
X = df.drop(columns=[target_column, 'Target'])
y = df.loc[:, ['Target']]  # single-column DataFrame (not a Series)
# Cast string/categorical feature columns to pandas 'category' dtype so that
# LightGBM can consume them as categorical features.
# Iterate over X's own columns: the loop body indexes X, and no separate
# `X_cat` frame is defined in this script.
for c in X.columns:
    col_type = X[c].dtype
    if col_type == 'object' or col_type.name == 'category':
        X[c] = X[c].astype('category')
model = lgb.LGBMClassifier(
    learning_rate=0.1,
    n_estimators=100,
    max_depth=10,
    random_state=123,
    n_jobs=-1,
)
# Stratified K-fold: each split keeps the class proportions of y.
# shuffle=True is needed for random_state to have any effect; recent
# scikit-learn versions raise ValueError when random_state is given while
# shuffle is left at its default of False.
skf = StratifiedKFold(n_splits=kfold, shuffle=True, random_state=42)
score = []  # one [accuracy, recall, precision] row per fold
for i, (train_index, test_index) in enumerate(skf.split(X, y)):
    print('[Fold %d/%d]' % (i + 1, kfold))
    X_train, X_valid = X.iloc[train_index], X.iloc[test_index]
    y_train, y_valid = y.iloc[train_index], y.iloc[test_index]
    model.fit(X_train, y_train)
    # Predict once per fold and reuse the result for all three metrics.
    y_pred = model.predict(X_valid)
    accuracy = accuracy_score(y_valid, y_pred)
    recall = recall_score(y_valid, y_pred)
    precision = precision_score(y_valid, y_pred)
    score.append([accuracy, recall, precision])
    # Per-fold feature importances, stored in a column named after the fold.
    value_col = 'Val_' + str(i)
    fold_importance = pd.DataFrame(
        sorted(zip(model.feature_importances_, X_train.columns)),
        columns=[value_col, 'Feature'],
    )
    # Accumulate a wide table: one row per feature, one 'Val_<i>' column per fold.
    if i == 0:
        feature_imp = fold_importance
    else:
        feature_imp = feature_imp.merge(fold_importance, on=['Feature'], how='outer')
print("Time: ", int(time.time() - t0))
# Reshape the per-fold metric rows to long format: columns 'variable' (metric
# name) and 'value' (metric score), one row per fold and metric.
score = pd.DataFrame(score, columns=['accuracy', 'recall', 'precision'])
score = score.melt()
# Select the numeric 'value' column explicitly. Aggregating the whole frame
# would also try to average the string 'variable' column, which fails on
# pandas >= 2.0, and positional `[0]` on a labelled Series is deprecated.
accuracy_values = score.loc[score['variable'] == 'accuracy', 'value']
print("Avg. accuracy: ", accuracy_values.mean())
print("Std. accuracy: ", accuracy_values.std())
# Long format for the importances too: 'Feature', 'variable' (fold column), 'value'.
feature_imp = pd.melt(feature_imp, id_vars=['Feature'])
# plot
# Two side-by-side box plots in one 10x5 figure: the per-fold metric scores on
# the left and the per-fold feature importances on the right.
plt.figure(figsize=(10, 5))
panels = (
    (1, 'variable', score),
    (2, 'Feature', feature_imp),
)
for position, x_column, frame in panels:
    plt.subplot(1, 2, position)
    sns.boxplot(x=x_column, y='value', data=frame)
# The right-hand panel is still the current axes here, so only its x tick
# labels (the feature names) are rotated.
plt.xticks(rotation='vertical')
plt.tight_layout()
plt.show()