sktime / sktime

A unified framework for machine learning with time series
https://www.sktime.net
BSD 3-Clause "New" or "Revised" License

[ENH] Class Design for BaseDeepNetwork #3190

Open AurumnPegasus opened 2 years ago

AurumnPegasus commented 2 years ago

Is your feature request related to a problem? Please describe. As discussed in the Deep Learning mentoring meeting of 05/08/2022:

Current Implementation: BaseDeepNetwork is a base class for creating neural networks, and each specific neural network is a child of BaseDeepNetwork. For example, CNNNetwork, CNTCNetwork, LSTMNetwork etc. are classes inheriting from BaseDeepNetwork, each with a single function called build_model (which builds and returns the keras network).

For estimators, there are specific BaseDeepClassifier and BaseDeepRegressor classes, inheriting from BaseClassifier and BaseRegressor respectively. Specific estimators like CNNClassifier inherit from BaseDeepClassifier and, within the init method, create an object of the class CNNNetwork. Then, in the fit method, the estimator gets its keras neural network by calling the build_model method of that CNNNetwork object.

The advantage of this design is that when creating a CNNRegressor, we do not need to rewrite the code for the main CNN; instead we reuse CNNNetwork in the same way CNNClassifier does. Note that CNNNetwork returns the keras network with all layers except the output layer, which is added by the specific estimator such as CNNClassifier.
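The composition design described above can be sketched roughly as follows. This is a simplified, framework-free illustration, not sktime's actual code: `(layer_name, config)` tuples stand in for real keras layers, the `n_filters` parameter and `_input_shape` helper are hypothetical, and the intermediate BaseDeepClassifier/BaseDeepRegressor/BaseClassifier/BaseRegressor layers are omitted to keep it short.

```python
from abc import ABC, abstractmethod


class BaseDeepNetwork(ABC):
    """Simplified stand-in: a network object only builds the shared body."""

    @abstractmethod
    def build_model(self, input_shape):
        """Return the network body, without the output layer."""


class CNNNetwork(BaseDeepNetwork):
    def __init__(self, n_filters=(6, 12)):  # hypothetical parameter
        self.n_filters = n_filters

    def build_model(self, input_shape):
        # (layer_name, config) tuples stand in for real keras layers
        layers = [("input", input_shape)]
        for n in self.n_filters:
            layers += [("conv1d", n), ("avg_pool", 3)]
        return layers + [("flatten", None)]


def _input_shape(X):
    # hypothetical helper; X is nested lists of shape (n_instances, n_timepoints, n_channels)
    return (len(X[0]), len(X[0][0]))


class CNNClassifier:  # in sktime: CNNClassifier -> BaseDeepClassifier -> BaseClassifier
    def __init__(self, n_filters=(6, 12)):
        self.n_filters = n_filters
        # composition: the estimator owns a network object
        self._network = CNNNetwork(n_filters=n_filters)

    def fit(self, X, y):
        body = self._network.build_model(_input_shape(X))
        # the output layer is added by the estimator, not by the network
        self.model_ = body + [("dense_softmax", len(set(y)))]
        return self


class CNNRegressor:  # in sktime: CNNRegressor -> BaseDeepRegressor -> BaseRegressor
    def __init__(self, n_filters=(6, 12)):
        self.n_filters = n_filters
        # the same CNN body is reused, with no copy-paste of the layers
        self._network = CNNNetwork(n_filters=n_filters)

    def fit(self, X, y):
        body = self._network.build_model(_input_shape(X))
        self.model_ = body + [("dense_linear", 1)]
        return self
```

Fitting both estimators on the same data yields identical network bodies that differ only in the final layer, which is the reuse benefit the design aims for.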

New proposals (from discussions with @fkiraly and @ltsaprounis today, and with @GuzalBulatova previously): Make BaseDeepClassifier inherit from both BaseClassifier and BaseDeepNetwork. BaseDeepClassifier would then have the same methods and structure as other classifiers via BaseClassifier, plus methods specific to deep learning models via BaseDeepNetwork. This would reduce the amount of redundant code across all DL models, and would make it easier for users to implement custom networks (once we move towards supporting that feature).
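A rough sketch of the proposed multiple-inheritance layout follows. `BaseClassifier` here is a hypothetical minimal stub, not sktime's real class, and the `build_network` and `_input_shape` names are illustrative assumptions; layers are again stand-in tuples rather than keras objects.

```python
from abc import ABC, abstractmethod


class BaseClassifier:
    """Hypothetical stub standing in for sktime's BaseClassifier."""

    def fit(self, X, y):
        self._fit(X, y)  # concrete logic lives in the subclass
        self.is_fitted = True
        return self


class BaseDeepNetwork(ABC):
    """Deep-learning-specific behaviour shared by all DL estimators."""

    @abstractmethod
    def build_network(self, input_shape):
        """Return the network body, without the output layer."""

    def _input_shape(self, X):
        # (n_timepoints, n_channels) of the first instance
        return (len(X[0]), len(X[0][0]))


class BaseDeepClassifier(BaseClassifier, BaseDeepNetwork):
    """Gets the classifier interface and the DL helpers from its two parents."""

    def _fit(self, X, y):
        body = self.build_network(self._input_shape(X))
        self.model_ = body + [("dense_softmax", len(set(y)))]


class CNNClassifier(BaseDeepClassifier):
    def __init__(self, n_filters=(6, 12)):  # hypothetical parameter
        self.n_filters = n_filters

    def build_network(self, input_shape):
        # (layer_name, config) tuples stand in for real keras layers
        layers = [("input", input_shape)]
        for n in self.n_filters:
            layers += [("conv1d", n), ("avg_pool", 3)]
        return layers + [("flatten", None)]
```

One trade-off is visible in this sketch: the CNN layers now live in CNNClassifier itself, so a CNNRegressor would need some other mechanism to share them with the classifier.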

An example of redundant code occurring across all DL classes: see PR #3128, where I implemented save model and load model functionality for DL models separately (since pickling keras networks is troublesome and not easily implementable). There, the same save and load functions need to be implemented in BaseDeepClassifier, BaseDeepRegressor and BaseDeepForecaster, which could easily have been solved by having a common DL class that all DL networks inherit from. I expect we will find other such redundancies as I continue to migrate DL models from sktime-dl to sktime.
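To illustrate how a common base class removes this duplication, here is a minimal sketch that defines pickling support once and lets every DL estimator inherit it. It is not the approach taken in PR #3128: `FakeKerasModel` is a hypothetical stand-in for a keras model (it refuses to be pickled but offers its own `to_dict`/`from_dict` serialisation, whereas real keras models use `model.save` and `keras.models.load_model`), and the `model_` and `_model_dict` attribute names are assumptions.

```python
import pickle


class FakeKerasModel:
    """Hypothetical stand-in for a keras model: unpicklable, self-serialising."""

    def __init__(self, weights):
        self.weights = weights

    def __reduce__(self):
        raise TypeError("cannot pickle this model directly")

    def to_dict(self):
        return {"weights": list(self.weights)}

    @classmethod
    def from_dict(cls, d):
        return cls(d["weights"])


class BaseDeepNetwork:
    """Single home for the save/load logic shared by all DL estimators."""

    def __getstate__(self):
        state = self.__dict__.copy()  # copy, so the live estimator is untouched
        model = state.pop("model_", None)
        # serialise the network with its own mechanism instead of pickle
        state["_model_dict"] = None if model is None else model.to_dict()
        return state

    def __setstate__(self, state):
        model_dict = state.pop("_model_dict", None)
        self.__dict__.update(state)
        if model_dict is not None:
            self.model_ = FakeKerasModel.from_dict(model_dict)


# Classifier and regressor inherit save/load instead of re-implementing it.
class BaseDeepClassifier(BaseDeepNetwork):
    def fit(self, X, y):
        self.classes_ = sorted(set(y))
        self.model_ = FakeKerasModel([0.0] * len(self.classes_))
        return self


class BaseDeepRegressor(BaseDeepNetwork):
    def fit(self, X, y):
        self.model_ = FakeKerasModel([0.0])
        return self
```

With this layout, `pickle.dumps` works on any fitted DL estimator even though pickling the model object directly raises `TypeError`, and a BaseDeepForecaster would get the same behaviour just by inheriting from BaseDeepNetwork.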

Input and discussion are needed to design BaseDeepNetwork better, so as to minimize code redundancy and make the structure more intuitive.

AurumnPegasus commented 2 years ago

Created a STEP document for this with two proposed solutions: https://github.com/sktime/enhancement-proposals/pull/26