cpalfonso / stellar-data-mining

Scripts to extract subduction-related data for mineral exploration data mining
1 star · 1 fork · source link

cross validation notebook - empty data? #6

Open RichardScottOZ opened 2 months ago

RichardScottOZ commented 2 months ago

With use_extracted_data = True

image

RichardScottOZ commented 2 months ago

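Abbreviation problem: the notebook selects regions by the abbreviations "NAm"/"SAm", but with use_extracted_data = True the "region" column appears to hold the full names ("North America", "South America"), so the regional selections come back empty. Workaround below.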

image

RichardScottOZ commented 2 months ago
# Stratify on the combined "<region>_<label>" key so the split is balanced
# over region/label combinations rather than over either column alone.
stratify = (
    data["region"]
    + "_"
    + data["label"]
)
data_train, data_test = train_test_split(
    data,
    test_size=0.2,
    random_state=random_seed,
    shuffle=True,
    stratify=stratify,
)
# Only rows with a known class ("positive"/"negative") are evaluated.
data_test = data_test[data_test["label"].isin({"positive", "negative"})]

# Positional (0-based) row numbers of each region within the filtered test
# set; they are later used to slice per-row prediction arrays.
# The "region" column has been seen with both abbreviated ("NAm"/"SAm") and
# full ("North America"/"South America") names; matching only one spelling
# silently yields empty index arrays, so both are accepted here.
indices_NA = np.where(
    data_test["region"].isin({"NAm", "North America"})
)[0]
indices_SA = np.where(
    data_test["region"].isin({"SAm", "South America"})
)[0]

# Training rows for the PU model: positives plus unlabelled rows (either
# spelling of "unlabelled" is accepted).
train_pu = data_train.loc[
    data_train["label"].isin({"positive", "unlabeled", "unlabelled"})
]
x_train_pu, y_train_pu = get_xy(train_pu)

# Training rows for the SVM model: only rows with a known class.
train_svm = data_train.loc[
    data_train["label"].isin({"positive", "negative"})
]
x_train_svm, y_train_svm = get_xy(train_svm)

# NOTE(review): get_xy is defined elsewhere; presumably it splits a frame
# into a feature matrix and a target vector -- confirm.
x_test, y_test = get_xy(data_test)

# Fit copies produced by clone() rather than the incoming model objects,
# rebinding the same names to the copies.
model_pu = clone(model_pu)
model_pu.fit(x_train_pu, y_train_pu)

model_svm = clone(model_svm)
model_svm.fit(x_train_svm, y_train_svm)

# Keep column 1 of predict_proba for every test row.
# NOTE(review): presumably column 1 is the "positive" class -- confirm
# against the models' class ordering.
probs_pu = model_pu.predict_proba(x_test)[:, 1]
probs_svm = model_svm.predict_proba(x_test)[:, 1]

# Regional subsets, selected by positional index into the test rows.
probs_pu_NA, probs_pu_SA = probs_pu[indices_NA], probs_pu[indices_SA]
probs_svm_NA, probs_svm_SA = probs_svm[indices_NA], probs_svm[indices_SA]

# Predicted probabilities keyed first by region ("All", "NA", "SA") and
# then by model ("PU", "SVM").
probs_dict = {
    region: {"PU": pu_probs, "SVM": svm_probs}
    for region, pu_probs, svm_probs in (
        ("All", probs_pu, probs_svm),
        ("NA", probs_pu_NA, probs_svm_NA),
        ("SA", probs_pu_SA, probs_svm_SA),
    )
}

# Per-model line styling (color and linestyle keyword arguments).
kwargs_dict = {
    "PU": dict(color="blue", linestyle="solid"),
    "SVM": dict(color="orange", linestyle="dashed"),
}

# Row counts per (set, region, label) across the train and test splits.
# Left as a bare expression so a notebook displays the resulting Series.
pd.concat(
    [
        data_train.assign(set="train"),
        data_test.assign(set="test"),
    ]
).groupby(["set", "region", "label"]).size()