mlr-org / mlr3keras

Deep learning for mlr3
GNU Lesser General Public License v3.0

Feature request: Tabnet #7

Closed: JackyP closed this issue 4 years ago

JackyP commented 5 years ago

The TabNet paper by Google researchers (https://arxiv.org/abs/1908.07442) claims that the architecture "outperforms other neural network and decision tree variants on a wide range of tabular data learning datasets and yields interpretable feature attributions and insights into the global model behavior."

There is a Python package for Keras/TensorFlow 2.0 (tabnet) and one for PyTorch (pytorch-tabnet).

Here is a short R proof of concept (POC) of the iris example from the former (tabnet):

```r
library("reticulate")
library("tensorflow")
library("keras")
# keras::install_keras(extra_packages = c("tensorflow-hub", "tabnet==0.1.3"))

# Use the Keras implementation bundled with TensorFlow (tf.keras)
use_implementation("tensorflow")

# tabnet 0.1.3; I could not get 0.1.4 to work and filed an issue
tabnet <- import("tabnet")

col_names <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")

# One numeric feature column per input variable
feature_columns <- lapply(col_names, function(x) tf$feature_column$numeric_column(x))

# Group Norm does better for small datasets
# n.b. the name of this class changes in 0.1.4 (which I could not get to work)
# Integer literals (3L etc.) make reticulate pass Python ints instead of floats
model <- tabnet$TabNetClassification(feature_columns, num_classes = 3L,
                                     feature_dim = 4L, output_dim = 4L,
                                     num_decision_steps = 2L, relaxation_factor = 1.0,
                                     sparsity_coefficient = 1e-5, batch_momentum = 0.98,
                                     virtual_batch_size = NULL, norm_type = "group",
                                     num_groups = 1L)
model %>% compile(
  loss = "categorical_crossentropy",
  optimizer = optimizer_adam(),
  metrics = c("accuracy")
)

# Inputs: a named list with one n x 1 matrix per feature column
x <- lapply(col_names, function(x) as.matrix(iris[x]))
names(x) <- col_names

# Targets: one-hot encoding of the three species
y <- model.matrix(~ 0 + iris$Species)

model %>%
  fit(x, y, epochs = 100, verbose = 2)
```
pfistfl commented 5 years ago

Looks very cool, will try it out asap!

pfistfl commented 5 years ago

So apparently this breaks for `model %>% fit(x, y, epochs = 100, verbose = 2, validation_split = 0.5)`, because the tensors cannot be sliced?

Any ideas whether we can fix this?

pfistfl commented 5 years ago

"The argument `validation_split` (generating a holdout set from the training data) is not supported when training from Dataset objects, since this feature requires the ability to index the samples of the datasets, which is not possible in general with the Dataset API."
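One possible workaround, sketched here under assumptions: split the data manually and pass the holdout through `fit()`'s `validation_data` argument instead of `validation_split`, so Keras never has to slice the inputs itself. The `split_holdout()` helper is hypothetical (it is not part of mlr3keras or tabnet). The split runs in base R; the `fit()` call is left as a comment because it needs the TabNet model from the POC, and it has not been checked against tabnet 0.1.3.

```r
# Hypothetical helper (not from mlr3keras or tabnet): split the named list of
# per-feature matrices `x` and the one-hot matrix `y` into training and
# validation sets by row index.
split_holdout <- function(x, y, holdout = 0.5, seed = 1L) {
  n <- nrow(y)
  set.seed(seed)  # note: sets the global RNG seed
  val_idx <- sample.int(n, size = floor(holdout * n))
  train_idx <- setdiff(seq_len(n), val_idx)
  # drop = FALSE keeps each feature as an n x 1 matrix; lapply keeps the names
  take <- function(idx) lapply(x, function(m) m[idx, , drop = FALSE])
  list(
    x_train = take(train_idx),
    y_train = y[train_idx, , drop = FALSE],
    x_val   = take(val_idx),
    y_val   = y[val_idx, , drop = FALSE]
  )
}

# Same input/target layout as in the POC
col_names <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")
x <- lapply(col_names, function(nm) as.matrix(iris[nm]))
names(x) <- col_names
y <- model.matrix(~ 0 + iris$Species)

s <- split_holdout(x, y, holdout = 0.5)

# Assumed usage with the TabNet model from the POC (untested):
# model %>% fit(s$x_train, s$y_train, epochs = 100, verbose = 2,
#               validation_data = list(s$x_val, s$y_val))
```

Splitting by row index outside Keras keeps every feature matrix aligned with the targets, which is exactly what `validation_split` cannot guarantee for Dataset inputs.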

pfistfl commented 4 years ago

Fixed via #8