ASUS-AICS / LibMultiLabel

A library for multi-class and multi-label classification
MIT License

Alter the Behaviors of ModelCheckpoint Callback #344

Closed donglihe-hub closed 10 months ago

donglihe-hub commented 10 months ago

Motivation

Lightning saves the model to disk in the main process, so training stalls until the checkpoint has been written. This is not a problem when models are small, but saving takes longer as the model size grows. It becomes a bottleneck when training models like AttentionXML on extreme multi-label datasets like AmazonCat-13K and Amazon-670K, where AttentionXML performs hundreds of save operations during training.

What does this PR do?

Alter the behavior of Lightning's ModelCheckpoint. This adds new features to Lightning's ModelCheckpoint while simplifying the arguments exposed to users and developers.

The new ModelCheckpoint uses the Lightning 2.1.3 API, so the tests will not pass until Eleven upgrades the packages. There is no conflict now, as I've moved the custom ModelCheckpoint to a new file.

New features:

  1. Cache the best (top 1) model state_dict to RAM during training.

  2. Save the cached state_dict to disk at the end of training.

  3. Save only the state_dict when save_weights_only=True. This contradicts Lightning's behavior: with save_weights_only=True I would expect only the state_dict to be saved, but Lightning also adds hyperparameters and some other data to the output. I chose to output only the state_dict. Since we don't use save_weights_only in LibMultiLabel, I don't think we need to worry too much about this change.
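The caching idea in items 1 and 2 can be sketched without any framework. This is a minimal illustration, not the code in this PR: the class name BestStateCache and its update/flush methods are hypothetical, plain Python lists stand in for tensors, and pickle stands in for torch.save. In a Lightning callback, update would presumably be called from a validation hook and flush from an end-of-training hook.

```python
import copy
import os
import pickle
import tempfile


class BestStateCache:
    """Hypothetical helper: keep the best state dict in RAM, write it once."""

    def __init__(self, filepath, mode="max"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.filepath = filepath
        self.mode = mode
        self.best_score = None
        self.best_state = None

    def _is_better(self, score):
        if self.best_score is None:
            return True
        if self.mode == "max":
            return score > self.best_score
        return score < self.best_score

    def update(self, score, state_dict):
        """Call after each validation run; no disk I/O happens here."""
        if self._is_better(score):
            self.best_score = score
            # Deep copy so later training steps cannot mutate the cached weights.
            self.best_state = copy.deepcopy(state_dict)
            return True
        return False

    def flush(self):
        """Call once at the end of training; overwrites any existing file."""
        if self.best_state is None:
            return None
        with open(self.filepath, "wb") as f:
            pickle.dump(self.best_state, f)
        return self.filepath


# Toy run: three "epochs" with made-up validation scores.
path = os.path.join(tempfile.mkdtemp(), "best.pkl")
cache = BestStateCache(path, mode="max")
weights = {"layer.weight": [0.0, 0.0]}
for epoch, score in enumerate([0.61, 0.74, 0.70]):
    weights["layer.weight"][0] = float(epoch)  # pretend a training step
    cache.update(score, weights)

saved_path = cache.flush()
with open(saved_path, "rb") as f:
    restored = pickle.load(f)
```

The training loop only pays for an in-memory copy when the score improves, and the single disk write happens after training. The trade-off is that the best weights are lost if the process crashes before flush runs.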

Removed features:

  1. The new ModelCheckpoint does not save the last model. The feature can be added on request, without the symlink issue, since the behavior is under our control.
  2. Top-k saving is no longer supported. This is reasonable, as we use save_top_k=1 in all cases.

Removed arguments:

# I do not adopt these arguments, as their logic is tangled up and thus difficult to use.
every_n_train_steps
train_time_interval
every_n_epochs

# No effect in the new ModelCheckpoint.
save_on_train_epoch_end

# Removed to force overwriting the existing model.
enable_version_counter

Test CLI & API (bash tests/autotest.sh)

Test APIs used by main.py.

Check API Document

If any new APIs are added, please check if the description of the APIs is added to API document.

Test quickstart & API (bash tests/docs/test_changed_document.sh)

If any APIs in quickstarts or tutorials are modified, please run this test to check if the current examples can run correctly after the modified APIs are released.

Eleven1Liu commented 10 months ago

related to #343

donglihe-hub commented 10 months ago

Closed, as this only relates to experiment code.