mlr-org / mlr3keras

Deep learning for mlr3
GNU Lesser General Public License v3.0

tabnet - ValueError: Number of groups (2) must be a multiple of the number of channels (13). #19

Closed: JackyP closed this issue 4 years ago.

JackyP commented 4 years ago

Minimal example:

library(mlr3)
library(mlr3keras)  # registers the "classif.tabnet" learner

wine_task <- tsk("wine")
lrn_tabnet <- lrn("classif.tabnet")
lrn_tabnet$train(wine_task)
#> ValueError: Number of groups (2) must be a multiple of the number of channels (13).

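This happens because tf-tabnet implements group normalisation with a default of num_groups = 2, which requires the input dimension (the number of channels) to be divisible by the number of groups. For this dataset the input dimension is 13, which is odd and in fact prime, so only num_groups = 1 or num_groups = 13 would work.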
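
A simplified sketch of group normalisation in base R shows where the constraint comes from: the feature axis is cut into `groups` equal slices, and each row is normalised within each slice. The function name `group_norm` and the `eps` value are illustrative assumptions, and the sketch omits the learned scale and offset that the real tf-tabnet layer applies.

```r
# Simplified group normalisation over the feature (channel) axis of a
# batch x channels matrix. Illustrative only: no learned scale/offset.
group_norm <- function(x, groups, eps = 1e-5) {
  n_channels <- ncol(x)
  # The channels must split into `groups` slices of equal size
  if (n_channels %% groups != 0) {
    stop(sprintf("%d channels cannot be split into %d equal groups",
                 n_channels, groups))
  }
  group_size <- n_channels %/% groups
  out <- x
  for (g in seq_len(groups)) {
    idx <- ((g - 1) * group_size + 1):(g * group_size)
    block <- x[, idx, drop = FALSE]
    # Each row is centred and scaled separately within the current group
    mu <- rowMeans(block)
    sigma2 <- rowMeans((block - mu)^2)
    out[, idx] <- (block - mu) / sqrt(sigma2 + eps)
  }
  out
}

set.seed(1)
x13 <- matrix(rnorm(4 * 13), nrow = 4)  # 13 channels, as for the wine task
x14 <- matrix(rnorm(4 * 14), nrow = 4)  # 14 channels, for comparison

res <- try(group_norm(x13, groups = 2), silent = TRUE)
inherits(res, "try-error")        # TRUE: 13 cannot be split into 2 groups
dim(group_norm(x13, groups = 1))  # 4 13: a single group always works
dim(group_norm(x14, groups = 2))  # 4 14
```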

Because mlr3keras one-hot encodes factor features behind the scenes, it is not obvious to the user whether the final input dimension will be odd or even.
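
To see how encoding hides the parity, here is a minimal sketch that counts the columns produced by full one-hot encoding. It assumes one column per numeric feature and one column per factor level; the helper `encoded_dim` is hypothetical, and mlr3keras may encode factors differently. The count is cross-checked against base R's `model.matrix`.

```r
# Hypothetical helper: input dimension after full one-hot encoding,
# assuming one column per numeric feature and one per factor level.
encoded_dim <- function(df) {
  per_column <- vapply(df, function(col) {
    if (is.factor(col)) nlevels(col) else 1L
  }, integer(1))
  sum(per_column)
}

df <- data.frame(
  x1 = c(0.5, 1.2, -0.3, 2.0),
  x2 = c(1, 0, 1, 0),
  colour = factor(c("red", "green", "blue", "red"))
)

encoded_dim(df)  # 5: two numeric columns plus three colour levels

# Cross-check with base R: full indicator coding, no dropped reference level
X <- model.matrix(~ . - 1, data = df,
                  contrasts.arg = list(colour = contrasts(df$colour, contrasts = FALSE)))
ncol(X)          # 5
```

The two raw columns plus one factor give an odd dimension here; a factor with two or four levels instead of three would make it even, even though the task itself looks the same to the user.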

I would therefore argue for setting the default to 1L rather than 2L, so that the learner trains with default settings even though the upstream package defaults to 2. Alternatively, we could use batch normalisation (batch_norm), as in the original TabNet paper.
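
Both proposed options avoid the divisibility constraint. The base R sketch below contrasts them on a 13-feature batch: group normalisation with a single group normalises each row over all features, while batch normalisation normalises each feature over the rows of the batch. Both functions are simplified illustrations (no learned parameters, no running statistics); TabNet itself uses ghost batch normalisation, which roughly applies the batch step to smaller virtual batches, and that detail is omitted.

```r
# Group normalisation with num_groups = 1: each row over all its features.
# Works for any number of features, since 1 divides every dimension.
single_group_norm <- function(x, eps = 1e-5) {
  mu <- rowMeans(x)
  sigma2 <- rowMeans((x - mu)^2)
  (x - mu) / sqrt(sigma2 + eps)
}

# Plain batch normalisation (training-mode statistics only):
# each feature (column) is normalised over the batch (rows).
batch_norm <- function(x, eps = 1e-5) {
  centred <- sweep(x, 2, colMeans(x))
  sigma2 <- colMeans(centred^2)
  sweep(centred, 2, sqrt(sigma2 + eps), "/")
}

set.seed(42)
x <- matrix(rnorm(8 * 13), nrow = 8)  # batch of 8 rows, 13 features

max(abs(rowMeans(single_group_norm(x))))  # close to 0: rows are centred
max(abs(colMeans(batch_norm(x))))         # close to 0: columns are centred
```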

pfistfl commented 4 years ago

Seems sensible. I guess we should then give some more info on how to obtain valid settings for num_groups. Should we export get_tf_num_features?
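
Once the encoded input dimension is known, valid settings are simply its divisors. The sketch below uses two hypothetical helpers, `valid_num_groups` and `check_num_groups`, which are not part of mlr3keras. It assumes `n_features` is the dimension after encoding, which is presumably what get_tf_num_features reports; that function's actual signature is not shown in this thread.

```r
# Hypothetical helper: all num_groups values that divide the input dimension
valid_num_groups <- function(n_features) {
  candidates <- seq_len(n_features)
  candidates[n_features %% candidates == 0]
}

# Hypothetical check: fail early with a readable message instead of a
# TensorFlow ValueError during training
check_num_groups <- function(n_features, num_groups) {
  valid <- valid_num_groups(n_features)
  if (!num_groups %in% valid) {
    stop(sprintf(
      "num_groups = %d is invalid for %d input features; valid values: %s",
      num_groups, n_features, paste(valid, collapse = ", ")
    ))
  }
  invisible(TRUE)
}

valid_num_groups(13)      # 1 13
valid_num_groups(16)      # 1 2 4 8 16
check_num_groups(16, 2L)  # passes silently
```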