automl / TabPFN

Official implementation of the TabPFN paper (https://arxiv.org/abs/2207.01848) and the tabpfn package.
http://priorlabs.ai
Apache License 2.0

running on 'glass' gives index error in cross_entropy #29

Closed: amueller closed this issue 1 year ago

amueller commented 1 year ago
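from tabpfn import TabPFNClassifier
from tabpfn.scripts import tabular_metrics  # module path in the original repo; may differ by version

# valid_datasets is assumed to be loaded beforehand (the repo's OpenML dataset list)
ds = valid_datasets[53]  # glass, id 41; ds[0] is the name, ds[1] the features, ds[2] the labels
print(ds[0], ds[1].shape)
xs, ys = ds[1].clone(), ds[2].clone()
# First half of the rows for training, second half for testing
eval_position = xs.shape[0] // 2
train_xs, train_ys = xs[0:eval_position], ys[0:eval_position]
test_xs, test_ys = xs[eval_position:], ys[eval_position:]
classifier = TabPFNClassifier()
print(classifier)
classifier.fit(train_xs, train_ys)
prediction_ = classifier.predict_proba(test_xs)
roc, ce = tabular_metrics.auc_metric(test_ys, prediction_), tabular_metrics.cross_entropy(test_ys, prediction_)
print('AUC', float(roc), 'Cross Entropy', float(ce))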

There seem to be 7 classes, but prediction_ only has 6 columns.
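To illustrate the mismatch, here is a minimal NumPy sketch. The `cross_entropy` function below is a hypothetical stand-in, not TabPFN's `tabular_metrics.cross_entropy`; it only assumes, as the error suggests, that each raw label is used directly as a column index into the probability matrix. With labels up to 6 but only 6 columns (indices 0 to 5), the lookup for label 6 goes out of bounds.

```python
import numpy as np


def cross_entropy(y_true, proba):
    """Hypothetical mean negative log-likelihood that indexes columns by raw label."""
    # proba[i, y_true[i]] assumes every label is a valid column index 0..n_columns-1
    picked = proba[np.arange(len(y_true)), y_true]
    return float(-np.mean(np.log(picked)))


# Labels as in 'glass': class 3 never occurs, so only 6 distinct classes exist
y = np.array([0, 1, 2, 4, 5, 6])
proba = np.full((len(y), 6), 1 / 6)  # one column per observed class

try:
    cross_entropy(y, proba)
except IndexError as err:
    # Label 6 points at column 6, which does not exist in a 6-column matrix
    print("IndexError:", err)
```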

noahho commented 1 year ago

Hi Andreas, thanks for this report. This was caused by a missing class in the dataset, as you probably noticed as well: np.unique(ys) -> array([0, 1, 2, 4, 5, 6]), so label 3 never occurs. I added sklearn's LabelEncoder to our code to catch this. Best
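The LabelEncoder approach can be sketched as follows. This is a minimal illustration with made-up labels containing the same gap, not necessarily how the fix in the TabPFN code is implemented: `LabelEncoder` maps the sorted unique labels onto contiguous indices 0..n_classes-1, so column j of `predict_proba` lines up with encoded label j, and `classes_` / `inverse_transform` recover the original labels.

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# Raw labels with a gap: class 3 never appears (as in the 'glass' example)
ys = np.array([0, 1, 2, 4, 5, 6, 6, 2])

encoder = LabelEncoder()
ys_encoded = encoder.fit_transform(ys)  # sorted unique labels -> 0..n_classes-1

print(encoder.classes_)  # original label for each encoded index
print(ys_encoded)        # contiguous labels, safe to use as column indices

# Map encoded labels (e.g. argmax over predict_proba columns) back to originals
original = encoder.inverse_transform(ys_encoded)
```

The encoder should be fitted on the labels the model is trained with; passing a label to `transform` that was never seen at fit time raises a `ValueError`.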

amueller commented 1 year ago

I actually didn't catch the missing 3; that makes sense :)