wri / plantation_classifier

This research aims to spatially differentiate planted trees from natural trees using a transfer learning approach for image segmentation.

Migration to DVC #10

Closed: jessicarose00 closed this pull request 4 months ago.

jessicarose00 commented 5 months ago

This PR migrates the code base to an improved method for versioning data and models using DVC. DVC lets us track and save data and ML models, create and switch between versions of those data and models, and compare metrics across experiments. This PR restructures all scripts associated with model training and modifies those used for deployment.

General Changes

The repository structure is modified to include a new training pipeline with 5 stages:

  1. stage_load_data
  2. stage_prep_features
  3. stage_select_and_tune
  4. stage_train_model
  5. stage_evaluate_model

Parameters are extracted to a dedicated configuration file (params.yaml), and reusable code is migrated into separate modules for each stage. The experiment pipeline is automated with DVC and can be re-run to reproduce results whenever values in params.yaml change. The dvc.yaml file defines the machine learning workflow by listing the dependencies and outputs of each stage. The README.md is updated to document the new repository structure.
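To make the reproduction behavior concrete, here is a toy Python model of how a DVC-style pipeline decides which stages to re-run: a stage reruns when one of its params or dependencies changes, and its outputs then invalidate every downstream stage that consumes them. The stage names come from the list above; the dependency paths, output paths, and params.yaml sections (data, features, tune, train, evaluate) are hypothetical and not copied from this repository's dvc.yaml. Real DVC detects changes by comparing hashes recorded in dvc.lock; this sketch skips hashing and takes the set of changed params as given.

```python
"""Toy model of DVC-style stage invalidation (not DVC's own implementation)."""

# Hypothetical pipeline: stage names follow this PR, but the deps, outs and
# params.yaml sections are illustrative assumptions.
STAGES = {
    "stage_load_data": {
        "deps": set(), "params": {"data"}, "outs": {"data/raw"}},
    "stage_prep_features": {
        "deps": {"data/raw"}, "params": {"features"}, "outs": {"data/features"}},
    "stage_select_and_tune": {
        "deps": {"data/features"}, "params": {"tune"},
        "outs": {"models/best_params.json"}},
    "stage_train_model": {
        "deps": {"data/features", "models/best_params.json"},
        "params": {"train"}, "outs": {"models/model.cbm"}},
    "stage_evaluate_model": {
        "deps": {"data/features", "models/model.cbm"},
        "params": {"evaluate"}, "outs": {"metrics.json"}},
}


def stale_stages(stages, changed_params, changed_files=()):
    """Return the names of stages that must re-run, in pipeline order.

    A stage is stale if one of its params changed, or if one of its deps is a
    changed file or an output of an earlier stale stage. Dict insertion order
    is assumed to be a valid topological order (true for this linear pipeline).
    """
    changed_params = set(changed_params)
    dirty = set(changed_files)
    stale = []
    for name, spec in stages.items():
        if spec["params"] & changed_params or spec["deps"] & dirty:
            stale.append(name)
            dirty |= spec["outs"]  # consumers of these outputs become stale too
    return stale


# Changing only training params leaves loading, features and tuning cached.
print(stale_stages(STAGES, {"train"}))
# → ['stage_train_model', 'stage_evaluate_model']
```

The payoff of declaring dependencies and outputs per stage is visible in the example: editing a training parameter does not force the expensive feature-preparation and tuning stages to run again.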

Method Changes

In addition to the restructuring of the code base, a few methodological changes are incorporated into this PR and listed below.

Specific Requests

@JohnMBrandt, could you please review the use of a random seed in the data split and confirm any edits required, here. @JohnMBrandt, could you please also review and confirm the transition to a subset median for ARD, here. @rlrognstad, could you please review and confirm the use of CatBoost feature importance to replace a SHAP explainer, here. Otherwise, general review and comments are more than welcome!
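For context on the first request, the sketch below shows the property a fixed seed is meant to guarantee: the same inputs always produce the same train/test split, so metrics from different DVC experiments are computed on the same held-out samples. The helper name split_ids, the test fraction, and the idea of reading the seed from params.yaml are hypothetical; the repository may instead use a library splitter (for example, scikit-learn's train_test_split with its random_state argument).

```python
import random


def split_ids(sample_ids, test_fraction=0.2, seed=0):
    """Reproducibly split sample IDs into (train, test) lists.

    Hypothetical helper: in a DVC pipeline the seed would typically be read
    from params.yaml so that changing it is tracked like any other parameter.
    """
    ids = sorted(sample_ids)   # sort first so the caller's input order is irrelevant
    rng = random.Random(seed)  # dedicated RNG; leaves the global random state untouched
    rng.shuffle(ids)
    n_test = round(len(ids) * test_fraction)
    return ids[n_test:], ids[:n_test]


train, test = split_ids(range(10), test_fraction=0.3, seed=7)
print(len(train), len(test))  # 7 3
```

Using a dedicated random.Random instance rather than calling random.seed() keeps the split independent of any other code that draws random numbers before or during the split.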

JohnMBrandt commented 4 months ago

The YAML and config files look great.

for x in range(subsample):
    for i in indices:

med_ard = np.median(varied_median, axis=0).astype(np.float32) makes a duplicate of varied_median, saves the median to a new np.float64 array, and then casts that array to np.float32. If the original array is np.float32, this takes 3x as much RAM. It also increments the reference count of varied_median, which is not used again but, because its child is returned, will not be cleared by the garbage collector. The following uses 3x less RAM and does not have the reference-count problem: varied_median = np.median(np.float32(varied_median), axis = 0, overwrite_input = True).
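A minimal NumPy sketch of the suggested pattern, assuming varied_median is a list of per-date arrays that all share the same shape (the function name subset_median and the array shapes in the example are hypothetical). It uses np.asarray(..., dtype=np.float32) in place of the np.float32(...) call from the comment: that builds a single float32 array, and overwrite_input=True then lets np.median partition that array in place instead of making its own internal copy.

```python
import numpy as np


def subset_median(varied_median):
    """Per-pixel median over axis 0 (the date axis), returned as float32.

    If varied_median is already a float32 ndarray, np.asarray returns it
    without copying, and overwrite_input=True will reorder its contents, so
    only pass data that is not needed afterwards.
    """
    stacked = np.asarray(varied_median, dtype=np.float32)
    return np.median(stacked, axis=0, overwrite_input=True)


rng = np.random.default_rng(42)
frames = [rng.random((4, 5), dtype=np.float32) for _ in range(7)]  # 7 dates of 4x5 pixels
med_ard = subset_median(frames)
print(med_ard.shape, med_ard.dtype)  # (4, 5) float32
```

The trade-off is that the input buffer is scrambled by the in-place partition. That is acceptable here because, as noted above, varied_median is not used again once its median has been taken.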

jessicarose00 commented 4 months ago

Thanks for the review! To respond to your comments:

I will review and update the areas you've flagged where duplicate copies are being made.

jessicarose00 commented 4 months ago

I've made the following updates to src/features/create_xy.py:

Other additional edits: