drivendataorg / zamba

A Python package for identifying 42 kinds of animals, training custom models, and estimating distance from camera trap videos
https://zamba.drivendata.org/docs/stable/
MIT License
118 stars, 27 forks

Set up correct OHE labels for subsets that use default model labels #236

Closed: ejm714 closed this pull request 2 years ago

ejm714 commented 2 years ago

Fixes #234

#229 introduced a bug whereby the new columns added to the labels file were in a different order than the labels on the model. This PR fixes that by setting up the correct one-hot encoded labels in the `preprocess_labels` validator rather than in `instantiate_model`. The `use_default_model_labels` setting tells us whether the labels file should contain columns (with all zeroes) for species that are not present in the labels but are on the base model. Converting the labels to a `pd.Categorical` before calling `get_dummies` allows us to generate these columns.
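A minimal sketch of the idea, assuming a labels table with one `filepath` and one `label` column. `one_hot_labels`, `naive_one_hot`, and the five-species `MODEL_SPECIES` list are hypothetical stand-ins, not zamba's actual code or label set. The naive version only illustrates how appending missing columns can break the model's order; it is not the code from #229.

```python
from __future__ import annotations

import pandas as pd

# Hypothetical stand-in for the species list stored on a pretrained model;
# the real model has its own (longer) list in its own order.
MODEL_SPECIES = ["antelope_duiker", "bird", "blank", "chimpanzee_bonobo", "elephant"]


def naive_one_hot(labels: pd.DataFrame, model_species: list[str]) -> pd.DataFrame:
    """Illustrative only: encode present species, then append the missing ones."""
    ohe = pd.get_dummies(labels["label"], dtype=int)  # columns sorted by value
    for species in model_species:
        if species not in ohe.columns:
            ohe[species] = 0  # appended at the end, not at its model position
    return ohe


def one_hot_labels(
    labels: pd.DataFrame, model_species: list[str], use_default_model_labels: bool
) -> pd.DataFrame:
    """Sketch: one-hot encode `label`, one row per `filepath`."""
    if use_default_model_labels:
        # Declaring every model species as a category makes get_dummies emit one
        # column per category, in category order, all zeros for absent species.
        # Labels outside model_species would become all-zero rows, so they
        # should be rejected by validation before this point.
        species = pd.Categorical(labels["label"], categories=model_species)
    else:
        # Only the species that actually appear in the labels file.
        species = labels["label"]
    ohe = pd.get_dummies(species, dtype=int)
    ohe.index = labels["filepath"].to_numpy()
    # A video may carry several labels; collapse to one row per video.
    return ohe.groupby(level=0).max()


labels = pd.DataFrame(
    {
        "filepath": ["a.mp4", "b.mp4", "c.mp4"],
        "label": ["elephant", "blank", "chimpanzee_bonobo"],
    }
)

buggy = naive_one_hot(labels, MODEL_SPECIES)
fixed = one_hot_labels(labels, MODEL_SPECIES, use_default_model_labels=True)
print(list(buggy.columns))  # ['blank', 'chimpanzee_bonobo', 'elephant', 'antelope_duiker', 'bird']
print(list(fixed.columns))  # ['antelope_duiker', 'bird', 'blank', 'chimpanzee_bonobo', 'elephant']
```

Passing `use_default_model_labels=False` to the sketch yields only the three columns seen in the labels file, in alphabetical order.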

Running `zamba train --config tests/assets/sample_train_config.yaml` now works: the labels file contains only three of the species on the zamba model, yet training produces a model that outputs the full set of 32 labels.

netlify[bot] commented 2 years ago

Deploy Preview for silly-keller-664934 ready!

| Name | Link |
|---|---|
| Latest commit | 0fdbc20af975645e989a7d22e8b2c365fbd9f304 |
| Latest deploy log | https://app.netlify.com/sites/silly-keller-664934/deploys/63334834199b35000819d6d2 |
| Deploy Preview | https://deploy-preview-236--silly-keller-664934.netlify.app |

github-actions[bot] commented 2 years ago

🚀 Deployed on https://deploy-preview-236--silly-keller-664934.netlify.app

codecov-commenter commented 2 years ago

Codecov Report

Merging #236 (0fdbc20) into master (0a894a5) will increase coverage by 0.0%. The diff coverage is 100.0%.

Additional details and impacted files

```diff
@@           Coverage Diff           @@
##           master    #236   +/-   ##
======================================
  Coverage    87.2%   87.2%
======================================
  Files          28      28
  Lines        1961    1962    +1
======================================
+ Hits         1710    1711    +1
  Misses        251     251
```

| [Impacted Files](https://codecov.io/gh/drivendataorg/zamba/pull/236?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=drivendataorg) | Coverage Δ | |
|---|---|---|
| [zamba/models/model\_manager.py](https://codecov.io/gh/drivendataorg/zamba/pull/236/diff?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=drivendataorg#diff-emFtYmEvbW9kZWxzL21vZGVsX21hbmFnZXIucHk=) | `84.3% <ø> (-0.5%)` | :arrow_down: |
| [zamba/models/config.py](https://codecov.io/gh/drivendataorg/zamba/pull/236/diff?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=drivendataorg#diff-emFtYmEvbW9kZWxzL2NvbmZpZy5weQ==) | `96.9% <100.0%> (+<0.1%)` | :arrow_up: |
| [zamba/models/utils.py](https://codecov.io/gh/drivendataorg/zamba/pull/236/diff?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=drivendataorg#diff-emFtYmEvbW9kZWxzL3V0aWxzLnB5) | `100.0% <100.0%> (ø)` | |
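Why the report shows a 0.0% change: a quick check, assuming Codecov's usual definition of line coverage as hits divided by tracked lines (hits + misses + partials, with no partials listed in this report). The single added line is covered, so the ratio moves by well under a tenth of a percentage point.

```python
# Line coverage = hits / (hits + misses + partials); this report lists no partials.
base = 1710 / (1710 + 251)  # master: 1961 tracked lines
head = 1711 / (1711 + 251)  # #236: 1962 tracked lines, the added line is covered
print(f"{base:.1%} -> {head:.1%} ({head - base:+.1%})")  # 87.2% -> 87.2% (+0.0%)
```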