keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

`Checkpoint` was expecting detection_head to be a trackable object (an object derived from `Trackable`) #19425

Open sineeli opened 4 months ago

sineeli commented 4 months ago

Model Garden saves the decoder, detection head, and mask head as part of the checkpoint. I saw the same issue with the optimizer, and it seems to have been resolved in the latest release. Is there a way to include these modules in tf.train.Checkpoint as well?

    input_specs = keras.layers.InputSpec(shape=[None, None, None, 3])
    backbone = resnet.ResNet(model_id=50, input_specs=input_specs)
    decoder = fpn.FPN(
        min_level=3, max_level=7, input_specs=backbone.output_specs)
    rpn_head = dense_prediction_heads.RPNHead(
        min_level=3, max_level=7, num_anchors_per_location=3)
    detection_head = instance_heads.DetectionHead(num_classes=2)
    roi_generator_obj = roi_generator.MultilevelROIGenerator()
    roi_sampler_obj = roi_sampler.ROISampler()
    roi_aligner_obj = roi_aligner.MultilevelROIAligner()
    detection_generator_obj = detection_generator.DetectionGenerator()
    if include_mask:
      mask_head = instance_heads.MaskHead(num_classes=2, upsample_factor=2)
      mask_sampler_obj = mask_sampler.MaskSampler(
          mask_target_size=28, num_sampled_masks=1)
      mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(crop_size=14)
    else:
      mask_head = None
      mask_sampler_obj = None
      mask_roi_aligner_obj = None
    model = maskrcnn_model.MaskRCNNModel(
        backbone,
        decoder,
        rpn_head,
        detection_head,
        roi_generator_obj,
        roi_sampler_obj,
        roi_aligner_obj,
        detection_generator_obj,
        mask_head,
        mask_sampler_obj,
        mask_roi_aligner_obj,
        min_level=3,
        max_level=7,
        num_scales=3,
        aspect_ratios=[1.0],
        anchor_size=3)
    expect_checkpoint_items = dict(
        backbone=backbone,
        decoder=decoder,
        rpn_head=rpn_head,
        detection_head=[detection_head])
    if include_mask:
      expect_checkpoint_items['mask_head'] = mask_head
    self.assertAllEqual(expect_checkpoint_items, model.checkpoint_items)

    # Test save and load checkpoints.
    ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
    save_dir = self.create_tempdir().full_path
    ckpt.save(os.path.join(save_dir, 'ckpt'))

    partial_ckpt = tf.train.Checkpoint(backbone=backbone)
    partial_ckpt.read(tf.train.latest_checkpoint(
        save_dir)).expect_partial().assert_existing_objects_matched()

    if include_mask:
      partial_ckpt_mask = tf.train.Checkpoint(
          backbone=backbone, mask_head=mask_head)
      partial_ckpt_mask.restore(tf.train.latest_checkpoint(
          save_dir)).expect_partial().assert_existing_objects_matched()

Error:

Traceback (most recent call last):
  File "/Users/sineeli/Documents/tfm-keras3/env/lib/python3.10/site-packages/absl/testing/parameterized.py", line 320, in bound_param_test
    return test_method(self, *testcase_params)
  File "/Users/sineeli/Documents/tfm-keras3/models/official/vision/modeling/maskrcnn_model_test.py", line 395, in test_checkpoint
    ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)
  File "/Users/sineeli/Documents/tfm-keras3/env/lib/python3.10/site-packages/tensorflow/python/checkpoint/checkpoint.py", line 2225, in __init__
    _assert_trackable(converted_v, k)
  File "/Users/sineeli/Documents/tfm-keras3/env/lib/python3.10/site-packages/tensorflow/python/checkpoint/checkpoint.py", line 1573, in _assert_trackable
    raise ValueError(
ValueError: `Checkpoint` was expecting detection_head to be a trackable object (an object derived from `Trackable`), got [<DetectionHead name=detection_head_33, built=False>]. If you believe this object should be trackable (i.e. it is part of the TensorFlow Python API and manages state), please open an issue.

sachinprasadhs commented 4 months ago

A similar issue with the optimizer was resolved in the latest Keras version; see https://github.com/keras-team/keras/issues/19321

sineeli commented 4 months ago

Please find the attached gist for reference.

grasskin commented 4 months ago

Could you try rerunning with !pip install keras-nightly and !pip install tf-nightly to make sure this is still reproducible after #19321?

sineeli commented 4 months ago

> Could you try rerunning with !pip install keras-nightly and !pip install tf-nightly to make sure this is still reproducible after #19321?

The fix in #19321 is for the optimizer, which has already been fixed and is present in the recent release, if I am not wrong. The issue above is about checkpointing other, extra modules. The error seems to persist even after installing the nightly builds. Here is the gist.

Thanks

fchollet commented 4 months ago

The error message says:

ValueError: `Checkpoint` was expecting detection_head to be a trackable object (an object derived from `Trackable`), got [<DetectionHead name=detection_head_33, built=False>]

What is the type of DetectionHead here? Is it a Keras layer? All Keras layers are Trackables. Is the issue that it is wrapped in a list?

Could this be an issue with the Model Garden package?

sineeli commented 4 months ago

From the code base, I can see that the Mask-RCNN model expects the detection head to be a tf.keras.layers.Layer or a List[tf.keras.layers.Layer], so we cannot track a list of heads.

DetectionHead is wrapped in a list while building the Mask-RCNN model, so passing it with or without a list results in a list-wrapped Keras layer.

I checked the same code with Keras 2 (tf-keras), and it seems to work in that version: gist

So, to overcome this issue, is there a change I can make on the Model Garden side to make them trackable?

Thanks

fchollet commented 4 months ago

> So, to overcome this issue, is there a change I can make on the Model Garden side to make them trackable?

You could fork the Model Garden repo and open a PR?

sineeli commented 4 months ago

I apologize for any confusion in my previous message. I'm currently refactoring the code within Model Garden itself to be compatible with Keras 3.

My question is: when checkpointing a Keras model, is there a way to keep track of a list of Keras layers in Keras 3 without running into the error above? The same code works in Keras 2.

Lists of heads (layers) are used for the Cascade R-CNN model.

Hope this clarifies.

Thanks
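
One possible workaround for the question above, sketched under assumptions rather than confirmed by the thread: since the error is raised for a list-valued entry, the checkpoint items could be flattened so that every value handed to tf.train.Checkpoint is a single layer. The helper name `flatten_checkpoint_items` and the `name_index` key scheme are hypothetical and are not part of Keras, TensorFlow, or Model Garden. Plain objects stand in for layers so the sketch runs without TensorFlow.

```python
from typing import Any, Dict, Mapping


def flatten_checkpoint_items(items: Mapping[str, Any]) -> Dict[str, Any]:
    """Expand list/tuple values into one entry per element.

    Hypothetical helper for illustration only. A key such as
    "detection_head" holding two heads becomes "detection_head_0"
    and "detection_head_1".
    """
    flat: Dict[str, Any] = {}

    def add(key: str, value: Any) -> None:
        # Refuse to silently overwrite an entry with the same name.
        if key in flat:
            raise ValueError(f"Duplicate checkpoint key: {key}")
        flat[key] = value

    for name, value in items.items():
        if isinstance(value, (list, tuple)):
            # isinstance also matches list subclasses, so a framework's
            # own list wrapper is expanded the same way as a plain list.
            for index, element in enumerate(value):
                add(f"{name}_{index}", element)
        else:
            add(name, value)
    return flat


# Stand-ins for layers; in real use these would be Keras layers.
backbone, head_a, head_b = object(), object(), object()
flat_items = flatten_checkpoint_items(
    {"backbone": backbone, "detection_head": [head_a, head_b]}
)
print(sorted(flat_items))
# → ['backbone', 'detection_head_0', 'detection_head_1']
```

The flattened dictionary could then be passed as `tf.train.Checkpoint(model=model, **flat_items)`. Note that this changes the checkpoint keys, so a checkpoint written this way would not line up with one saved under a list-valued `detection_head` key, and whether it fits Model Garden's existing checkpoints is not established here.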