mlexchange / mlex_dlsia_segmentation_prototype


Refactor train #28

Open TibbersHao opened 2 months ago

TibbersHao commented 2 months ago

This PR addresses issue #26 by revamping the current code structure to be more unit-test friendly and modular. The work is 80% finished and the pytests are ready to be executed.

Major Feature Upgrade

  1. Updated the training script to use small functions for each step; common functionality has been saved as utility functions.
  2. Updated docstrings for each function.
  3. Each function is accompanied by one or several pytests.
  4. Tests for model building and training.

Minor Feature Upgrade

  1. Updated QLTY version

Known Issue

  1. The model loading function, which leverages DLSIA's pre-built loading function, apparently has mismatched model weights in state_dict() compared to the trained model. The cause is unknown and the investigation is still underway.
  2. Package compatibility. At the moment the pydantic version has been temporarily downgraded to 1.10.15 in order to mitigate the incompatibility with an older version of tiled[all] in the dev requirements. Updating the tiled version will require a rewrite of the array space allocation function using DataSource from tiled, and this will be in the scope of a separate PR.

To Do's

  1. Adding a quick inference step right after training.
  2. Breaking down the long crop_seg_save() function used in the inference script into small, manageable functions.
  3. Adding pytests for the tiled array allocation function.
  4. Finishing the inference script.
  5. Format cleaning.
  6. Deleting functions that are no longer used and merging seg_utils.py into utils.py for clarity.
dylanmcreynolds commented 2 months ago

I pulled this down and ran pytest. I see three pytest errors. The build is also failing with a flake8 error. You should run both flake8 . and pytest before committing. Since this project uses pre-commit, you can also do the following:

pre-commit install

Then, if you want to test things out without committing, run:

pre-commit run --all-files

This is the equivalent of doing what pre-commit does when you commit, which includes running flake8 and black.

One other tip: after you run pre-commit, black will reformat your code. You can then run the following to add those changes to your next commit:

git add .
TibbersHao commented 2 months ago


Sounds good. Initially I was planning to run the pre-commit check and fix all format-related changes when I wrap up the inference part, but I will include these changes with my next commit.

On the pytest side, could you send me the errors you got from pytest? I was expecting only one known error but you got three instead.

TibbersHao commented 1 month ago

Nice progress on the refactoring and great to see the variety of test cases being included.

I've reviewed the test cases, and overall there seems to be a bit of a disconnect between the test cases and the functions/functionality actually being tested. It appears that the functions to be tested are executed within fixtures, and the test cases for the most part consist only of checks for expected outcomes. This makes it a little hard to follow what each test case ensures, as one first has to inspect which fixture was used and which function was called within that fixture.

To improve clarity and maintainability, I recommend using fixtures solely for setting up the input for the functions to be tested, explicitly calling the function to be tested within the test case, and then checking for expected outcomes as is done now.

Thanks for the suggestion. Regarding the use of fixtures here, Dylan and I had a discussion last Friday, and it seemed that keeping the current code structure would be a safe play: the new utility functions are called in a chained manner in the train script, so there would be more duplication if I called them directly in the test functions. I realize this is not the conventional way of using fixtures, and I am open to more discussion here.

I would additionally recommend looking into whether a fixture is the right way to parametrize each test. In some cases, the use of pytest.mark.parametrize may suffice, or even be advantageous, as multiple parameter combinations and outcomes can be tested.

I will look into that, thanks for pointing it out.
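
For reference, a minimal sketch of the pytest.mark.parametrize pattern suggested above. The clamp function is just a hypothetical stand-in for one of the small utility functions; each tuple in the decorator becomes an independent test case:

```python
import pytest


def clamp(value, low, high):
    # Hypothetical stand-in for a small utility under test;
    # replace with the real import from the package.
    return max(low, min(value, high))


@pytest.mark.parametrize(
    "value, low, high, expected",
    [
        (5, 0, 10, 5),    # value inside the range is returned unchanged
        (-3, 0, 10, 0),   # value below the range is clipped to the lower bound
        (42, 0, 10, 10),  # value above the range is clipped to the upper bound
    ],
)
def test_clamp(value, low, high, expected):
    assert clamp(value, low, high) == expected
```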

Regarding the refactoring, there seems to still be some work you are planning to do, given the listed todos and some larger sections of commented-out code. Are you planning to address these under this PR, or create a separate one for inference? Could you clarify the difference between the functions in utils and the functions in seg_utils?

What I am planning to do within this PR is to finish the inference script (both quick inference right after training and full inference) so that those large commented-out sections will be gone. This is necessary to wrap up this PR, and I am not planning further feature upgrades regarding the tiled functionality, as I do not want to add to the complexity of this PR (which is already a bit too big).

In the first round of development, I put the functions that directly target the crucial steps of segmentation into seg_utils and the more general ones into utils. With the new diagram and refactoring, I see no reason to keep two different files, as it will cause confusion, so I am merging all utility functions into utils for clarity. seg_utils will be gone.

dylanmcreynolds commented 1 month ago

The more I think about it, the more I think you should address the fixtures. I think as a general rule, fixtures are for reusable setup code, not the code that you're testing.

A well-written test of an API like yours can serve as documentation for how to use your API. If you have a test that shows the flow of the calls that your API requires, people reading it can easily see what to do. Hiding the code that is being tested in fixtures makes that harder. I think you should rearrange this so that the calls being tested are not hidden in conftest.py.
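
As an illustration of that split (the fixture and function names below are hypothetical stand-ins, not the repository's actual API): the fixture only provides the input data, and the call being verified is visible in the test body rather than in conftest.py:

```python
import pytest
import torch


@pytest.fixture
def image_batch():
    # Setup only: deterministic fake data, no calls to the code under test.
    torch.manual_seed(0)
    return torch.rand(2, 1, 32, 32)


def normalize(batch):
    # Hypothetical stand-in for a function under test; replace with the real import.
    return (batch - batch.mean()) / (batch.std() + 1e-8)


def test_normalize_centers_data(image_batch):
    # The function under test is called here, in plain sight.
    result = normalize(image_batch)
    assert result.shape == image_batch.shape
    assert torch.isclose(result.mean(), torch.tensor(0.0), atol=1e-5)
```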

TibbersHao commented 1 month ago


Sounds good to me, and will do after wrapping up the function development.

TibbersHao commented 1 month ago

@dylanmcreynolds @Wiebke the latest commit covers two upgrades for the pytests:

Please give it another check and let me know if you have any comments. I believe this PR should be reviewed and merged before more features are added, as it is already too big at the moment.

TibbersHao commented 1 month ago

As requested by @Wiebke, here is a change log capturing the differences for testing pipeline connections:

The revamped version:

Training

  1. Provide I/O and model parameters in a YAML file, following the examples in the "example_yamls" folder. Note that the structure and entries of the YAML file remain the same as in the current version on the main branch.
  2. Activate the conda environment.
  3. Run the training script: python src/train.py <yaml file path>. The command remains the same as in the current version.
  4. The training script creates a model saving directory at the models_dir from the YAML file, then trains the model and saves it to that directory along with metrics. This part remains the same as in the current version.
  5. Once training is finished, train.py kicks off a quick inference using the trained model on the annotated images and saves the result to the seg_tiled_uri indicated in the YAML file.

Inference

  1. Provide I/O and model parameters in a YAML file, following the examples in the "example_yamls" folder. Note that the structure and entries of the YAML file remain the same as in the current version on the main branch. mask_tiled_uri and mask_tiled_api_key are optional in this case.
  2. Activate the conda environment.
  3. Run the full inference script: python src/segment.py <yaml file path>. The command remains the same as in the current version.
  4. The inference script now only covers full inference.
  5. The result is saved to the seg_tiled_uri indicated in the YAML file. This remains the same as in the current version.
TibbersHao commented 3 weeks ago

Thanks for the feedback. Here is a summary of the plans to move this PR forward:

TiledDataset

validate_parameters

utils.py

build_network

Testing

Performance

phzwart commented 3 weeks ago

The multiple sigmoid / softmax application is an issue.

Please send me an image stack and an associated annotation so I can build a notebook to see what is going on and how things perform.

P

On Wed, Sep 4, 2024 at 4:39 PM Wiebke Köpp wrote:

Wiebke Köpp requested changes on this pull request.

Thank you for your hard work on this and for providing the high-level description.

I was actually mostly looking for information on what needs to change in other repositories. I gather quick inference does not rely on uid_retrieve being set to anything, and we now need to reduce the two function calls here to a single one, i.e. use only segmentation.py#L77-L95: https://github.com/mlexchange/mlex_highres_segmentation/blob/82de8225022dcd2969b170e6be3895417489b862/callbacks/segmentation.py#L77-L95

Overall, this refactoring is great. The modularity of functions is greatly improved, making the code more organized and easier to navigate. I have a few suggestions and observations in regards to the refactoring and testing, some of which I think are important to incorporate under this PR.

TiledDataSet: TiledDataSet was created as a subclass of torch.utils.data.Dataset with the intention of making it easier to construct data loaders. Since then the code has evolved. Neither the transform nor the qlty-patching functionality is still in use; instead, the class is used only for retrieval of frames. In some places where the class functionality could be used, Tiled clients are accessed directly. I think now is the time to remove any obsolete code and make sure the class reflects how we actually use it. While the new parameters make it very clear in which "mode" the class is currently used, they are in principle already captured by the presence of a mask. I suggest either simplifying this class (removing all qlty functionality) or creating dedicated classes for the different use cases.
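
For illustration, a minimal sketch of what a frame-retrieval-only dataset could look like; the constructor arguments and the way the Tiled clients are indexed here are assumptions, not the current TiledDataSet API:

```python
import torch
from torch.utils.data import Dataset


class FrameDataset(Dataset):
    """Hypothetical slimmed-down dataset: one frame (and optional mask) per index."""

    def __init__(self, data_client, mask_client=None):
        # data_client / mask_client are assumed to expose .shape and
        # numpy-style integer indexing, as Tiled array clients do.
        self.data_client = data_client
        self.mask_client = mask_client

    def __len__(self):
        return self.data_client.shape[0]

    def __getitem__(self, idx):
        frame = torch.as_tensor(self.data_client[idx])
        if self.mask_client is None:
            return frame
        return frame, torch.as_tensor(self.mask_client[idx])
```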

validate_parameters: This function passes input parameters to the respective parameter class and otherwise mostly relies on Pydantic's type validation. I think we are missing out on the opportunity to have all parameters validated after this function call. In the current version, we still cannot trust that all input parameters are actually valid, as some of the validation happens only when networks are built (for activation, convolution, ...) or while training (for the criterion). I suggest restructuring the parameter/network classes such that, after they are instantiated, all parameters are actually validated, making use of Pydantic's validators (https://docs.pydantic.dev/1.10/usage/validators/). I would further connect the build_network methods to the respective classes. We could make use of the factory pattern to eliminate some lengthy if/then/else statements based on the network name. When upgrading Tiled, we would need to follow the Pydantic migration guide (https://docs.pydantic.dev/latest/migration/#changes-to-validators) to replace the obsolete decorators.
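
As a sketch of that kind of up-front validation under the currently pinned Pydantic 1.10 (the class and field names are illustrative, not the existing parameter classes):

```python
from pydantic import BaseModel, validator


class NetworkParams(BaseModel):
    # Illustrative fields only; the real parameter classes differ.
    num_layers: int = 10
    activation: str = "ReLU"

    @validator("activation")
    def activation_must_be_known(cls, v):
        allowed = {"ReLU", "LeakyReLU", "Sigmoid"}
        if v not in allowed:
            raise ValueError(f"activation must be one of {sorted(allowed)}, got {v!r}")
        return v

    @validator("num_layers")
    def num_layers_must_be_positive(cls, v):
        if v < 1:
            raise ValueError("num_layers must be >= 1")
        return v
```

With this, NetworkParams(num_layers=0) fails immediately with a ValidationError at instantiation instead of surfacing later during network construction or training.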

utils.py: While the modularity of functions has greatly improved, many functions have been moved to utils even though they are closely connected to specific steps or concepts. I suggest organizing them differently, perhaps grouping them according to suitable topics, e.g. model_io, data_io, network, ...

Testing: The previously mentioned disconnect between fixtures and test cases remains. This leads to test cases being less legible and points to the test cases having some overall interdependency. I am concerned that the pattern of placing the actual functions to be tested in fixtures may in the future cause fixture creation to fail rather than causing test cases to fail. I think the high-level concern here is that the test setup essentially does end-to-end testing, but in steps. The numbering of the test files further implies a dependency that I would recommend steering away from. For example, any training or inference steps that require a TiledDataSet could be made independent of Tiled by mocking the class. You have clearly already put a lot of thought into the test setup, and it is a challenging aspect of this project, but please revisit this, taking into account the importance of independent, isolated testing.

Additionally, the inclusion of the one set of bad parameters causes some complicated logic that is difficult to trace back due to the issue above. In tests intended for parameter validation, there is a check whether a parameter is of type AssertionError, but within the test case it is not immediately clear why the passed parameter (a fixture) would be of that type. Consider creating a dedicated test case for failure rather than skipping this case in all other tests. Also note that within the GitHub action, testing causes 206 warnings (some of which have been addressed in Tiled, bluesky/tiled#676), some related to the test setup, and there are some important aspects that are not being tested (e.g. ensure_parent_containers). Furthermore, the testing and example YAMLs do not fully reflect how training and inference are run from our frontend application. This means issues like validate_parameters only working if a mask_uri is provided (which is not required in inference) are not being caught.

Finally (and this may be one of the more controversial points of this review), I do not think all tests included right now are really necessary, e.g. for functions that are essentially just getters (passing parameters to a parameter class and checking if they are still the same) or are close to trivial (find_device or create_directory).
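
For example, a data-loading test could stand in a fake for TiledDataSet instead of talking to a Tiled server. The shapes and the use of MagicMock below are illustrative; the actual interface of TiledDataSet may differ:

```python
import torch
from torch.utils.data import DataLoader
from unittest import mock


def test_loader_batches_frames_without_tiled():
    # A stand-in for TiledDataSet: only the pieces the loader touches.
    fake_dataset = mock.MagicMock()
    fake_dataset.__len__.return_value = 4
    fake_dataset.__getitem__.side_effect = lambda i: (
        torch.rand(1, 32, 32),          # frame
        torch.randint(0, 2, (32, 32)),  # mask
    )

    loader = DataLoader(fake_dataset, batch_size=2)
    frames, masks = next(iter(loader))

    assert frames.shape == (2, 1, 32, 32)
    assert masks.shape == (2, 32, 32)
```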

Performance: This is more for a future PR, perhaps in collaboration with @Giselleu https://github.com/Giselleu.

  • Inference normalizes data on the CPU before moving the data to the device. Should this be the transform that was removed from TiledDataSet?
  • segment_single_frame calls .eval().to(device) on the given network. I assume that later calls thus no longer cause a move of the network weights to the GPU? segment_single_frame moves and concatenates results on the CPU, which will cause stitching to happen on the CPU as well.
  • Training loads all data and masks into one numpy array each. Is there a limit on the number of slices for which this is feasible?

Final things I noticed while taking a look at the overall current state, not necessarily this PR:

  • MSDNets have a final_layer parameter to which we pass final_layer=nn.Softmax(dim=1), but all networks apply a final_layer within segment_single_frame. Does that mean we pass MSDNet results through a final_layer twice (see the sketch after this list)? Could this be another reason (aside from the number of layers being set too low) for our consistently sub-par MSDNet performance? (Typically everything is segmented as a single class.) @phzwart https://github.com/phzwart
  • The functions for partial inference and full inference appear to have a large amount of overlap and only differ in the retrieval of network and device as well as the qlty setup which has already happened for partial inference.
  • build_network indicates that we support different data dimensionalities, but I do not see that reflected in testing.
  • The ensemble build methods make use of a lot of default parameters that we do not expose to the user. Is this intended? If the function you crafted here to set up networks does not exist in dlsia itself, should this functionality be moved there?
  • CrossEntropyLoss is the only loss tested for, but we have exposed alternative losses in the interface. However, specifying any loss without the parameters weights, ignore_index or size_average will cause errors. Should we support only CrossEntropyLoss?
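
A self-contained check of why the possible double softmax flagged above matters (this is not the project's actual inference path, just an illustration): applying softmax to already-softmaxed outputs compresses the probabilities toward uniform, which can wash out class margins:

```python
import torch
import torch.nn as nn

softmax = nn.Softmax(dim=1)

# Fake logits for a 3-class problem at a single pixel, shape (batch, classes, pixels).
logits = torch.tensor([[[4.0], [0.5], [-2.0]]])

once = softmax(logits)
twice = softmax(once)  # what happens if the network output is softmaxed again downstream

print(once.squeeze())   # ≈ [0.968, 0.029, 0.002] -> one clearly dominant class
print(twice.squeeze())  # ≈ [0.564, 0.221, 0.215] -> much flatter, margins shrink
```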

I suggest we meet offline to discuss further and come up with a plan to tackle the observations above and/or record issues for future work.

