cog-imperial / OMLT

Represent trained machine learning models as Pyomo optimization formulations

Support for scikit neural networks and renaming omlt.onnx.py #62

Open joshuahaddad opened 2 years ago

joshuahaddad commented 2 years ago

This code adds an interface for scikit-learn MLPRegressor() objects via the skl2onnx (sklearn-onnx) library, along with support for scikit-learn offset scaling objects. Every scikit-learn scaler that applies a linear (offset-and-factor) transformation is supported, including StandardScaler, MaxAbsScaler, MinMaxScaler, and RobustScaler.
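The reason these four scalers can share one code path is that each reduces to the same offset-and-factor form, x_s = (x - offset) / factor, which is the form we assume OMLT's offset scaling expects. The sketch below shows that reduction under stated assumptions: the helper names (`from_standard_scaler`, `scale`, etc.) are hypothetical and not part of this PR, the scalers are assumed to use their default settings, and the fitted attributes (`mean_`, `center_`, `min_`, `scale_`) are passed in as plain lists rather than read from scikit-learn objects.

```python
from typing import List, Tuple

Vec = List[float]
Pair = Tuple[Vec, Vec]  # (offset, factor), so that x_s = (x - offset) / factor


def from_standard_scaler(mean_: Vec, scale_: Vec) -> Pair:
    # StandardScaler (defaults): x_s = (x - mean_) / scale_
    return list(mean_), list(scale_)


def from_robust_scaler(center_: Vec, scale_: Vec) -> Pair:
    # RobustScaler (defaults): x_s = (x - center_) / scale_
    return list(center_), list(scale_)


def from_max_abs_scaler(scale_: Vec) -> Pair:
    # MaxAbsScaler: x_s = x / scale_, i.e. a zero offset
    return [0.0] * len(scale_), list(scale_)


def from_min_max_scaler(min_: Vec, scale_: Vec) -> Pair:
    # MinMaxScaler: x_s = x * scale_ + min_
    #             = (x - (-min_ / scale_)) / (1 / scale_)
    offset = [-m / s for m, s in zip(min_, scale_)]
    factor = [1.0 / s for s in scale_]
    return offset, factor


def scale(x: Vec, offset: Vec, factor: Vec) -> Vec:
    # Apply the shared offset-and-factor form element-wise.
    return [(xi - o) / f for xi, o, f in zip(x, offset, factor)]


# A MinMaxScaler fitted on a feature ranging over [2, 6] with the default
# feature_range (0, 1) has scale_ = 0.25 and min_ = -0.5.
offset, factor = from_min_max_scaler([-0.5], [0.25])
print(offset, factor)                # [2.0] [4.0]
print(scale([4.0], offset, factor))  # [0.5], same as 4.0 * 0.25 - 0.5
```

MinMaxScaler is the only one whose stored attributes need rearranging, because scikit-learn parameterizes it as a multiply-then-add rather than a subtract-then-divide.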

Some changes were required in onnx_parser to handle conventions specific to skl2onnx output, such as biases being stored as (1, n) matrices instead of (n,) vectors.
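To illustrate why the bias shape matters, here is a minimal NumPy sketch, not the PR's actual parser change: the helper `normalize_bias` and the `dense` function are hypothetical, and weights are assumed to be laid out as (n_in, n_out). Adding a (1, n) bias produces the same numbers as an (n,) bias but silently changes the output shape to (1, n), which would break any code that indexes neurons along a 1-D axis.

```python
import numpy as np


def normalize_bias(bias, n_out: int) -> np.ndarray:
    # Hypothetical helper: accept a bias stored either as an (n_out,) vector
    # or as a (1, n_out) matrix and always return an (n_out,) vector.
    b = np.asarray(bias, dtype=float)
    if b.shape == (1, n_out):
        b = b.reshape(n_out)
    if b.shape != (n_out,):
        raise ValueError(
            f"unexpected bias shape {b.shape}; expected ({n_out},) or (1, {n_out})"
        )
    return b


def dense(x: np.ndarray, weights: np.ndarray, bias) -> np.ndarray:
    # Affine part of a dense layer; weights has shape (n_in, n_out).
    return x @ weights + normalize_bias(bias, weights.shape[1])


x = np.array([1.0, 2.0])
W = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])
b_vec = np.array([0.5, 0.5, 0.5])  # (n,) convention
b_mat = b_vec.reshape(1, 3)        # (1, n) convention

print((x @ W + b_mat).shape)  # (1, 3): broadcasting changed the output shape
print(dense(x, W, b_mat))     # [1.5 2.5 3.5]
```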

Additionally, because sklearn_reader.py imports the OMLT ONNX reader, that module was renamed from onnx.py to onnx_reader.py: the name omlt.onnx.py conflicts with the top-level onnx library and leads to a circular import.
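A simplified, hypothetical reproduction of this kind of naming conflict follows. It assumes the directory containing a file named onnx.py is searched before any installed onnx package (here, the current directory of a `python -c` subprocess); whether this matches the exact import chain in OMLT before the rename is an assumption. Inside the local onnx.py, `import onnx` returns the module that is still being initialized, so any attribute lookup on it fails.

```python
import subprocess
import sys
import tempfile
from pathlib import Path


def run_with_shadowing_module() -> subprocess.CompletedProcess:
    with tempfile.TemporaryDirectory() as tmp:
        # A local module named onnx.py that itself tries to use the onnx library.
        Path(tmp, "onnx.py").write_text("import onnx\nonnx.load\n")
        # With `python -c`, the current directory is searched first, so
        # `import onnx` resolves to the local file, not an installed package.
        return subprocess.run(
            [sys.executable, "-c", "import onnx"],
            cwd=tmp,
            capture_output=True,
            text=True,
        )


result = run_with_shadowing_module()
print(result.returncode != 0)  # True: the import fails
# The last stderr line is an AttributeError on the partially initialized
# 'onnx' module; its exact wording varies between Python versions.
print(result.stderr.strip().splitlines()[-1])
```

Renaming the module, as this PR does, removes the collision because `import onnx` can then only resolve to the real package.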

codecov[bot] commented 2 years ago

Codecov Report

Merging #62 (83f96bc) into main (683caa7) will decrease coverage by 0.04%. The diff coverage is 93.18%.


```diff
@@            Coverage Diff             @@
##             main      #62      +/-   ##
==========================================
- Coverage   94.15%   94.10%   -0.05%     
==========================================
  Files          24       25       +1     
  Lines        1231     1272      +41     
  Branches      186      192       +6     
==========================================
+ Hits         1159     1197      +38     
- Misses         42       43       +1     
- Partials       30       32       +2     
```
| Impacted Files | Coverage Δ |
|---|---|
| src/omlt/io/onnx_reader.py | `85.71% <ø> (ø)` |
| src/omlt/io/sklearn_reader.py | `92.10% <92.10%> (ø)` |
| src/omlt/io/__init__.py | `100.00% <100.00%> (ø)` |
| src/omlt/io/onnx_parser.py | `95.06% <100.00%> (+0.06%)` ⬆️ |

Legend: `Δ = absolute <relative> (impact)`, `ø = not affected`, `? = missing data`. Powered by Codecov. Last update 683caa7...83f96bc.