scikit-learn-contrib / hiclass

A Python library for hierarchical classification compatible with scikit-learn
BSD 3-Clause "New" or "Revised" License
114 stars, 20 forks

Add support for allowing multi-dimensional inputs (ndim > 2) #minor #97

Closed. ashishpatel16 closed this 1 year ago

ashishpatel16 commented 1 year ago

Image data has three dimensions per image (height, width, and RGB channels), so a list of images has four dimensions.
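As a minimal NumPy sketch of the shape arithmetic (the image count and the 32 x 32 resolution are hypothetical), a batch of RGB images is a 4-D array. The sketch also flattens each image into one row, which gives a 2-D input without any library change; that is shown only for comparison, is not what this PR proposes, and discards the explicit spatial layout.

```python
import numpy as np

# Hypothetical dataset: 100 RGB images, each 32 x 32 pixels with 3 colour channels.
n_images, height, width, channels = 100, 32, 32, 3
X_images = np.zeros((n_images, height, width, channels))

# 4 dimensions: (images, height, width, channels), so scikit-learn's
# default validation (ndim <= 2) rejects it.
print(X_images.ndim)  # 4

# For comparison only (not part of this PR): flatten each image into a single
# feature vector of height * width * channels values, giving a 2-D array.
X_flat = X_images.reshape(n_images, -1)
print(X_flat.shape)  # (100, 3072)
```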

Currently, passing such inputs raises the following error:

Traceback (most recent call last):
  File "/home/paula/documents/mastersproject/hierarchical-explainability/animal_dataset/train_with_hiclass.py", line 8, in <module>
    my_classifier.fit(X_train, y_train)
  File "/home/paula/documents/mastersproject/hierarchical-explainability/.venv/lib/python3.11/site-packages/hiclass/LocalClassifierPerParentNode.py", line 100, in fit
    super()._pre_fit(X, y, sample_weight)
  File "/home/paula/documents/mastersproject/hierarchical-explainability/.venv/lib/python3.11/site-packages/hiclass/HierarchicalClassifier.py", line 138, in _pre_fit
    self.X_, self.y_ = self._validate_data(
                       ^^^^^^^^^^^^^^^^^^^^
  File "/home/paula/documents/mastersproject/hierarchical-explainability/.venv/lib/python3.11/site-packages/sklearn/base.py", line 622, in _validate_data
    X, y = check_X_y(X, y, **check_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/paula/documents/mastersproject/hierarchical-explainability/.venv/lib/python3.11/site-packages/sklearn/utils/validation.py", line 1146, in check_X_y
    X = check_array(
        ^^^^^^^^^^^^
  File "/home/paula/documents/mastersproject/hierarchical-explainability/.venv/lib/python3.11/site-packages/sklearn/utils/validation.py", line 951, in check_array
    raise ValueError(
ValueError: Found array with dim 4. LocalClassifierPerParentNode expected <= 2.

This PR passes allow_nd=True to the input validation in the fit pipeline, and allow_nd=True, ensure_2d=False to the input validation in the predict pipeline, so that multi-dimensional inputs are accepted.
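The effect of these flags can be sketched with scikit-learn's public validators check_X_y and check_array, used here as stand-ins for the validation hiclass runs in _pre_fit and at predict time. The data is hypothetical, and passing multi_output=True (for the two-level label matrix) is an assumption about how hiclass calls the validator, not a copy of its code.

```python
import numpy as np
from sklearn.utils.validation import check_array, check_X_y

# Hypothetical data: 6 RGB images of 4 x 4 pixels with two-level hierarchical labels.
rng = np.random.default_rng(0)
X = rng.random((6, 4, 4, 3))
y = np.array([["animal", "cat"]] * 3 + [["animal", "dog"]] * 3)

# Default validation rejects ndim > 2, matching the ValueError in the traceback above.
try:
    check_X_y(X, y, multi_output=True)
    default_rejects_nd = False
except ValueError:
    default_rejects_nd = True

# Fit-time validation with allow_nd=True keeps the 4-D array intact.
X_fit, y_fit = check_X_y(X, y, multi_output=True, allow_nd=True)

# Predict-time validation with allow_nd=True and ensure_2d=False.
X_pred = check_array(X, allow_nd=True, ensure_2d=False)

print(default_rejects_nd, X_fit.shape, X_pred.shape)
```

Note that relaxing the check only lets the 4-D array through validation; the local classifiers that hiclass fits on it must themselves accept multi-dimensional input.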

iwan-tee commented 1 year ago

Your changes look correct and are heading in the right direction.

mirand863 commented 1 year ago

Hi @ashishpatel16,

Thank you for the pull request! It looks good to me, but it is failing some linting checks. Can you please run black and commit again? It is also a good idea to set up pre-commit according to the instructions in the contributing file.

codecov-commenter commented 1 year ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Comparison is base (0f9f955) 98.40% compared to head (c2b1180) 98.40%.

Additional details and impacted files

```diff
@@           Coverage Diff           @@
##             main      #97   +/-   ##
=======================================
  Coverage   98.40%   98.40%
=======================================
  Files           8        8
  Lines         566      566
=======================================
  Hits          557      557
  Misses          9        9
```


ashishpatel16 commented 1 year ago

@mirand863 The linting errors have finally been resolved. Could you merge it if everything looks all right?