mlr-org / mlr3keras

Deep learning for mlr3
GNU Lesser General Public License v3.0

Add support for TabNet #8

Closed: pfistfl closed this 4 years ago

pfistfl commented 5 years ago

Requested in #7

This is a very generic proof of concept (POC); I think it will need a little more thought to make things clean.

JackyP commented 4 years ago

Thinking about this a bit more - is the flexibility of the x_transform and y_transform needed, or should it be more constrained?

The reference docs list only a limited number of input formats for fit():

https://keras.rstudio.com/reference/fit.html

Vector, matrix, or array of training data (or list if the model has multiple inputs). If all inputs in the model are named, you can also pass a list mapping input names to data.

The default case uses the matrix-like format, and the TabNet case uses the list-like format.

Could the function just check

length(model$input_names)==1

and use that to choose between the matrix-like and list-like formats by default, I wonder?
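A minimal base-R sketch of the check proposed above. The helper name format_x() is hypothetical and not part of mlr3keras, and it assumes that in the multi-input case each named model input corresponds to exactly one feature column. In practice input_names would come from model$input_names, and the returned object would be passed as x to keras::fit().

```r
# Hypothetical helper (not part of mlr3keras): choose the x format for
# keras::fit() based on how many inputs the model has.
# Assumption: with multiple inputs, each input is named after one feature column.
format_x = function(input_names, features) {
  features = as.data.frame(features)
  if (length(input_names) == 1L) {
    # Single input: pass a plain numeric matrix (the default, matrix-like case).
    return(as.matrix(features))
  }
  # Multiple inputs: every input name must match a feature column.
  missing_cols = setdiff(input_names, colnames(features))
  if (length(missing_cols) > 0L) {
    stop("No feature column for input(s): ", paste(missing_cols, collapse = ", "))
  }
  # Build a named list mapping each input name to an n x 1 matrix (list-like case).
  x = lapply(input_names, function(nm) as.matrix(features[[nm]]))
  names(x) = input_names
  x
}

# Usage with a toy data set:
df = data.frame(a = c(1, 2, 3), b = c(4, 5, 6))
str(format_x("dense_input", df))  # a 3 x 2 numeric matrix
str(format_x(c("a", "b"), df))    # a named list of two 3 x 1 matrices
```

Dispatching on the number of inputs keeps the common single-input case unchanged, while the named list matches the docs' note that named inputs can be fed as a list mapping input names to data.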

pfistfl commented 4 years ago

Hey, thanks for the comments.

I was hoping I could generalize that! I will try to come up with a prototype at the end of the week.

You mentioned that the y transform may be handy for multi-output models, and that mlr3 does not currently seem designed for multi-task regression, although the logic makes sense. Yes, and I think that will come soon(ish). You also suggested that a test case for the y_transform would help coverage: will do!

pfistfl commented 4 years ago

Hey, I did a major refactoring. This lets me do a lot of things in far less code. (I hope things do not get too complicated as a result.)

This now includes TabNet and regression support.

codecov-io commented 4 years ago

Codecov Report

Merging #8 into master will increase coverage by 2.88%. The diff coverage is 96.84%.


@@            Coverage Diff             @@
##           master       #8      +/-   ##
==========================================
+ Coverage   92.74%   95.62%   +2.88%     
==========================================
  Files           5        7       +2     
  Lines         193      366     +173     
==========================================
+ Hits          179      350     +171     
- Misses         14       16       +2

Impacted Files              Coverage Δ
R/LearnerRegrKeras.R        100% <100%> (ø)
R/KerasArchitecture.R       95.45% <95.45%> (ø)
R/LearnerTabNet.R           96.36% <96.36%> (ø)
R/LearnerClassifKeras.R     98.14% <96.42%> (-1.86%) ↓
R/LearnerKerasFF.R          96.51% <96.51%> (ø)


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 2d43cea...dfd0ffe.

pfistfl commented 4 years ago

@JackyP Any thoughts? I feel I might have made things more complicated than necessary. If you think this is alright, we can merge it and then try to get PR #12 merged as well?

JackyP commented 4 years ago

It does feel like it could be simplified a little, but it was not too hard to work with. Keen to hear your views on #12 as well.