neuronets / nobrainer

A framework for developing neural network models for 3D image processing.

Getting the checkpoint based on creation time is probably not a good idea #331

Open hvgazula opened 4 months ago

hvgazula commented 4 months ago

https://github.com/neuronets/nobrainer/blob/976691d685824fd4bba836498abea4184cffd798/nobrainer/processing/checkpoint.py#L57

What am I trying to do? Initialize from a previous checkpoint to resume training for more epochs.

For example, the following snippet

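from nobrainer.models import unet
from nobrainer.processing.segmentation import Segmentation

# checkpoint_filepath is defined earlier in the training script.
try:
    # Resume from the most recent checkpoint under checkpoint_filepath.
    bem = Segmentation.init_with_checkpoints(
        "unet",
        model_args=dict(batchnorm=True),
        checkpoint_filepath=checkpoint_filepath,
    )
except Exception:
    # No usable checkpoint yet: build a fresh model that writes checkpoints there.
    bem = Segmentation(
        unet,
        model_args=dict(batchnorm=True),
        multi_gpu=True,
        checkpoint_filepath=checkpoint_filepath,
    )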

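should initialize from a checkpoint if checkpoint_filepath exists. However, because the latest checkpoint is picked with os.path.getctime, it can collide with other folders created during training (predictions or anything else): the most recently created folder wins, even if it is not a checkpoint.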
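
A minimal stdlib sketch of that failure mode, assuming the loader globs the subfolders of os.path.dirname(checkpoint_filepath) and keeps the one with the largest os.path.getctime. latest_by_ctime is a hypothetical stand-in, not the nobrainer function itself:

```python
import os
import tempfile
import time
from glob import glob


def latest_by_ctime(checkpoint_filepath):
    """Hypothetical re-creation of a ctime-based lookup: scan every subfolder of
    os.path.dirname(checkpoint_filepath) and return the most recently created."""
    parent = os.path.dirname(checkpoint_filepath)
    candidates = glob(os.path.join(parent, "*/"))
    return max(candidates, key=os.path.getctime)


with tempfile.TemporaryDirectory() as out:
    # Mimic a run directory: a checkpoint folder first, then predictions/.
    ckpts = os.path.join(out, "nobrainer_ckpts")
    os.makedirs(os.path.join(ckpts, "01"))
    time.sleep(1.1)  # keep ctimes distinct even on filesystems with 1 s resolution
    os.makedirs(os.path.join(out, "predictions"))
    picked = os.path.basename(os.path.normpath(latest_by_ctime(ckpts)))

print(picked)  # predictions: the prediction folder wins, not the checkpoint
```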

Solution:

hvgazula commented 4 months ago

In a nutshell, resuming from an existing checkpoint using the API tools is still not working cleanly. It works just fine with TensorFlow's built-in BackupAndRestore callback.

hvgazula commented 4 months ago

Also see https://github.com/neuronets/nobrainer/issues/332

hvgazula commented 4 months ago

Appending a trailing "/" to checkpoint_filepath resolved this; see https://github.com/neuronets/nobrainer_training_scripts/commit/d5d1de07f6b8fde0d6471326f50c9bae6289aad1 🤦‍♂️
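
Why the trailing slash matters, assuming the folder to scan is derived with os.path.dirname (the paths below are placeholders):

```python
import os

# Placeholder paths for illustration only.
without_slash = "output/run1/nobrainer_ckpts"
with_slash = "output/run1/nobrainer_ckpts/"

# dirname() drops everything after the last separator, so without the trailing
# slash the scanned folder is the whole run directory (logs/, predictions/, ...).
scan_without = os.path.dirname(without_slash)
scan_with = os.path.dirname(with_slash)

print(scan_without)  # output/run1
print(scan_with)     # output/run1/nobrainer_ckpts
```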

hvgazula commented 4 months ago

The getctime-based lookup only works if the checkpoint filepath has the epoch in its name.

For example, if checkpoint_filepath = f"output/{output_dirname}/nobrainer_ckpts/" + "{epoch:02d}", the output directory (in addition to other folders) looks as follows:

[Screenshot, 2024-05-11: contents of the output directory]

Explanation of the folders:

  1. backup is written by the BackupAndRestore callback (this will go away now).
  2. logs are the TensorBoard logs.
  3. model_ckpts is me saving the model weights at the end of each epoch.
  4. This is the ModelCheckpoint output (provided by the API), but it looks like I could do the same as item 3 with a few extra flags.
  5. predictions are plots or outputs generated at test time right after each epoch. These could be separated out (if needed) once we have a checkpoint from every epoch in item 4.
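
To make the epoch-in-the-name layout concrete, this is how a templated checkpoint_filepath like the one above expands, assuming Keras-style {epoch:02d} formatting with 1-based epochs; output_dirname = "run1" is a placeholder:

```python
import os

output_dirname = "run1"  # placeholder run name
checkpoint_filepath = f"output/{output_dirname}/nobrainer_ckpts/" + "{epoch:02d}"

# Keras fills in {epoch} when it writes each checkpoint.
per_epoch = [checkpoint_filepath.format(epoch=e) for e in (1, 2, 10)]
print(per_epoch)

# A dirname-based loader now scans nobrainer_ckpts/, which holds only the
# per-epoch folders, so sibling folders such as predictions/ never compete.
scan_dir = os.path.dirname(checkpoint_filepath)
print(scan_dir)  # output/run1/nobrainer_ckpts
```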

Summary:

  1. Setting a checkpoint_filepath that includes the epoch, and doing away with item 3, will enable loading from checkpoints cleanly. Otherwise, we may want to write improved load logic for the case where no per-epoch folders are created.
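
One possible shape for that improved load logic, sketched with a hypothetical helper latest_epoch_checkpoint: rank folders by the epoch number in their names instead of by creation time, and ignore anything that is not a number (predictions/, logs/, ...).

```python
import os
import re
import tempfile
from glob import glob


def latest_epoch_checkpoint(checkpoint_dir):
    """Hypothetical helper: return the subfolder of checkpoint_dir whose name is
    the largest epoch number, skipping non-numeric folders. None if none exist."""
    best_epoch, best_path = -1, None
    for path in glob(os.path.join(checkpoint_dir, "*/")):
        name = os.path.basename(os.path.normpath(path))
        if re.fullmatch(r"\d+", name) and int(name) > best_epoch:
            best_epoch, best_path = int(name), path
    return best_path


with tempfile.TemporaryDirectory() as d:
    for name in ("01", "02", "10", "predictions"):
        os.makedirs(os.path.join(d, name))
    latest_name = os.path.basename(os.path.normpath(latest_epoch_checkpoint(d)))

print(latest_name)  # 10
```

Comparing int(name) rather than the raw strings keeps epoch 10 ahead of epoch 9 even without zero padding.
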
hvgazula commented 4 months ago

In hindsight, we should include BackupAndRestore in addition to ModelCheckpoint, because the latter only saves a checkpoint at the end of each epoch. That is not enough if the model passes through the entire dataset and fails just before the checkpoint is written, whereas BackupAndRestore has a save_freq argument that can be taken advantage of.
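
A sketch of combining the two callbacks, assuming TF 2.11 or newer (where BackupAndRestore accepts save_freq). The toy model, random data, and temporary folders are placeholders, not the nobrainer training setup:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

run_dir = tempfile.mkdtemp()  # stand-in for output/{output_dirname}
backup_dir = os.path.join(run_dir, "backup")
ckpt_dir = os.path.join(run_dir, "nobrainer_ckpts")
os.makedirs(ckpt_dir, exist_ok=True)

# Toy model and data standing in for the real segmentation model and dataset.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

callbacks = [
    # Back up training state every 4 batches, so a failure late in an epoch
    # does not throw away the whole epoch's progress.
    tf.keras.callbacks.BackupAndRestore(backup_dir=backup_dir, save_freq=4),
    # Still keep one weights file per completed epoch.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(ckpt_dir, "{epoch:02d}.weights.h5"),
        save_weights_only=True,
    ),
]

history = model.fit(x, y, batch_size=8, epochs=2, callbacks=callbacks, verbose=0)
print(sorted(os.listdir(ckpt_dir)))
```

Re-running the same script after an interruption lets BackupAndRestore restore the backed-up state before fit continues.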

hvgazula commented 4 months ago

Ouch: https://github.com/keras-team/tf-keras/issues/430. Looks like we will have to stick with ModelCheckpoint for now 😞, because I also intend to save the best model.
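
Staying with ModelCheckpoint, one way to keep both resumable per-epoch checkpoints and the best model is to register two ModelCheckpoint callbacks. A sketch with a toy model, random data, and placeholder paths, assuming TF 2.x:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()  # stand-in for the checkpoint folder

# Toy model and data; placeholders only.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
x = np.random.rand(80, 4).astype("float32")
y = np.random.rand(80, 1).astype("float32")

callbacks = [
    # One weights file per epoch, for resuming training.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(ckpt_dir, "{epoch:02d}.weights.h5"),
        save_weights_only=True,
    ),
    # A single file that is rewritten only when val_loss improves.
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(ckpt_dir, "best.weights.h5"),
        save_weights_only=True,
        monitor="val_loss",
        save_best_only=True,
    ),
]

model.fit(x, y, batch_size=8, epochs=3, validation_split=0.25,
          callbacks=callbacks, verbose=0)
print(sorted(os.listdir(ckpt_dir)))
```

Weights-only files do not include the optimizer state, so resuming from them restarts the optimizer; saving full models instead is the alternative if that matters.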