adriangb / scikeras

Scikit-Learn API wrapper for Keras.
https://www.adriangb.com/scikeras/
MIT License

RFC: Composable input/output pipeline #234

Open adriangb opened 3 years ago

adriangb commented 3 years ago

This is an attempt to collect ideas from various issues / PRs into a coherent framework and generalize our current input/output transformers.

@stsievert would you mind taking a look?

codecov-commenter commented 3 years ago

Codecov Report

Merging #234 (98d279b) into master (7c072fb) will not change coverage. The diff coverage is n/a.


@@           Coverage Diff           @@
##           master     #234   +/-   ##
=======================================
  Coverage   98.90%   98.90%           
=======================================
  Files           6        6           
  Lines         728      728           
=======================================
  Hits          720      720           
  Misses          8        8           



github-actions[bot] commented 3 years ago

📝 Docs preview for commit 98d279b at: https://www.adriangb.com/scikeras/refs/pull/234/merge/

adriangb commented 3 years ago

> Thanks for this RFC – this is easier to review than a PR.

Good to know it helps! And thank you for reviewing.

> What issue are you trying to solve with this RFC? What's an illustration of that issue?

It solves some user-facing issues and also lets us rework the internals to be cleaner and more flexible for users.

There are several different issues, but I think the clearest one is #160.

The general user request is to be able to use a tf.data.Dataset as input.

We could hack this together as is (by hardcoding bypasses to BaseWrapper._validate_data and the feature/target encoders), but even then we would lose attributes like n_features_in_ that are part of both our API and the Scikit-Learn API. This RFC solves the problem by making processing of tf.data.Dataset a first-class citizen. It would even let us move some of our validation/preprocessing onto tf.data.Dataset, which, for example, would allow us to provide n_features_in_ and classes_ regardless of whether the input was array-like or a tf.data.Dataset.
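A minimal sketch of that idea, assuming every input is first normalized into a stream of (X_batch, y_batch) pairs so that metadata is inferred the same way for both input types. To keep it dependency-light, a plain list of NumPy batch pairs stands in for a batched tf.data.Dataset; `as_batches` and `infer_meta` are hypothetical helpers for illustration, not SciKeras API.

```python
from typing import Iterable, Iterator, Optional, Tuple

import numpy as np

Batch = Tuple[np.ndarray, np.ndarray]


def as_batches(X, y=None, batch_size: int = 32) -> Iterator[Batch]:
    """Normalize input into a stream of (X_batch, y_batch) pairs.

    Hypothetical helper. If ``y`` is None, ``X`` is assumed to already be an
    iterable of (X_batch, y_batch) pairs (our stand-in for a batched
    tf.data.Dataset). Otherwise X and y are array-likes sliced into batches.
    """
    if y is None:
        for X_batch, y_batch in X:
            yield np.asarray(X_batch), np.asarray(y_batch)
        return
    X, y = np.asarray(X), np.asarray(y)
    if len(X) != len(y):
        raise ValueError(f"X has {len(X)} samples but y has {len(y)}")
    for start in range(0, len(X), batch_size):
        yield X[start:start + batch_size], y[start:start + batch_size]


def infer_meta(batches: Iterable[Batch]) -> dict:
    """Hypothetical helper: compute n_features_in_ and classes_ in one pass."""
    n_features_in: Optional[int] = None
    classes: Optional[np.ndarray] = None
    for X_batch, y_batch in batches:
        if X_batch.ndim != 2:
            raise ValueError("expected 2D feature batches")
        if n_features_in is None:
            n_features_in = X_batch.shape[1]
        elif X_batch.shape[1] != n_features_in:
            raise ValueError("inconsistent number of features across batches")
        # np.union1d returns the sorted unique values present in either input.
        batch_classes = np.unique(y_batch)
        classes = batch_classes if classes is None else np.union1d(classes, batch_classes)
    return {"n_features_in_": n_features_in, "classes_": classes}


X = np.arange(12).reshape(6, 2)
y = np.array([0, 1, 0, 2, 1, 0])

# Array-like input ...
meta_from_arrays = infer_meta(as_batches(X, y, batch_size=4))
# ... and pre-batched input yield the same metadata.
meta_from_stream = infer_meta(as_batches([(X[:3], y[:3]), (X[3:], y[3:])]))
```

Iterating a real tf.data.Dataset in eager mode yields tensors that np.asarray can convert, so a loop of this shape would plausibly carry over, but that is an assumption rather than something the RFC specifies. Note the trade-off: computing classes_ this way requires a full pass over the data.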

The other linked issues generally concern either adding more validations (#106, #143) or cleaning up our current validations (#111, #209). These two themes are intertwined because:

  1. Every added validation risks running into edge cases. If these validations aren't modular (e.g. if they're buried three layers deep in private method calls), we might block a use case we can't foresee right now just to provide a nice error message.
  2. If we add more ad-hoc private methods that do validation/transformation (e.g. BaseWrapper._validate_data, BaseWrapper._check_model_compatibility), we end up with a mess of private methods where it's hard to understand the execution order, the purpose of each method, etc.

This RFC addresses both by providing a unified interface for these validations/transformations that is modular, composable and public. I envision it helping our users (since they can disable or add validation/transformation steps) as well as us, the developers, since we can organize our default validations/transformations better.
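A minimal sketch of what such an interface might look like, assuming each step is a plain callable that takes and returns `(X, y)`. `InputPipeline`, `without`, `with_step` and all step functions are hypothetical names for illustration, not the RFC's or SciKeras's actual API.

```python
from typing import Callable, List, Optional, Tuple

import numpy as np

# Hypothetical step signature: take (X, y), return a possibly transformed (X, y).
Step = Callable[[object, object], Tuple[object, object]]


def check_same_length(X, y):
    """Default validation step: X and y must have the same number of samples."""
    if len(X) != len(y):
        raise ValueError(f"X has {len(X)} samples but y has {len(y)}")
    return X, y


def ensure_2d_features(X, y):
    """Default transformation step: turn a 1D feature array into one column."""
    X = np.asarray(X)
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    return X, y


class InputPipeline:
    """Ordered, named steps; the list order is the execution order."""

    def __init__(self, steps: List[Tuple[str, Step]]):
        self.steps = list(steps)

    def without(self, name: str) -> "InputPipeline":
        """Return a copy with the named step removed, e.g. to disable a check."""
        return InputPipeline([(n, s) for n, s in self.steps if n != name])

    def with_step(self, name: str, step: Step, after: Optional[str] = None) -> "InputPipeline":
        """Return a copy with `step` inserted after `after` (or appended)."""
        names = [n for n, _ in self.steps]
        idx = len(names) if after is None else names.index(after) + 1
        return InputPipeline(self.steps[:idx] + [(name, step)] + self.steps[idx:])

    def __call__(self, X, y):
        for _, step in self.steps:
            X, y = step(X, y)
        return X, y


# Default pipeline a wrapper might ship with.
default = InputPipeline([
    ("check_same_length", check_same_length),
    ("ensure_2d_features", ensure_2d_features),
])
X_out, y_out = default([1.0, 2.0, 3.0], [0, 1, 0])  # X_out has shape (3, 1)

# A user hitting a false positive drops one check without touching the others.
lenient = default.without("check_same_length")


# A hypothetical user-defined step slotted in at a chosen position.
def scale_features(X, y):
    return np.asarray(X, dtype=float) / 10.0, y


custom = default.with_step("scale_features", scale_features, after="ensure_2d_features")
```

Because the steps form an explicit ordered list, the execution order is visible in one place (point 2), and a check that misfires on an unforeseen edge case can be removed without subclassing or patching private methods (point 1). Scikit-Learn's own `Pipeline` uses a similar list of `(name, estimator)` tuples.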