wangermeng2021 / Scaled-YOLOv4-tensorflow2

A TensorFlow 2.x implementation of Scaled-YOLOv4, as described in "Scaled-YOLOv4: Scaling Cross Stage Partial Network"
Apache License 2.0

Unable to load checkpoint #10

Closed: aliencaocao closed this issue 3 years ago

aliencaocao commented 3 years ago

I assumed that loading a saved checkpoint works the same way as loading the pretrained weights. However, when I tried to load my own saved checkpoint and train again with exactly the same data and the same command, I got this error:

Traceback (most recent call last):
  File "train.py", line 310, in <module>
    main(args)
  File "train.py", line 151, in main
    pretrain_model.load_weights(args.p5_coco_pretrained_weights).expect_partial()
  File "C:\Program Files\Python38\lib\site-packages\tensorflow\python\keras\engine\training.py", line 2205, in load_weights
    status = self._trackable_saver.restore(filepath, options)
  File "C:\Program Files\Python38\lib\site-packages\tensorflow\python\training\tracking\util.py", line 1336, in restore
    base.CheckpointPosition(
  File "C:\Program Files\Python38\lib\site-packages\tensorflow\python\training\tracking\base.py", line 253, in restore
    restore_ops = trackable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
  File "C:\Program Files\Python38\lib\site-packages\tensorflow\python\training\tracking\base.py", line 972, in _restore_from_checkpoint_position
    current_position.checkpoint.restore_saveables(
  File "C:\Program Files\Python38\lib\site-packages\tensorflow\python\training\tracking\util.py", line 307, in restore_saveables
    new_restore_ops = functional_saver.MultiDeviceSaver(
  File "C:\Program Files\Python38\lib\site-packages\tensorflow\python\training\saving\functional_saver.py", line 345, in restore
    restore_ops = restore_fn()
  File "C:\Program Files\Python38\lib\site-packages\tensorflow\python\training\saving\functional_saver.py", line 321, in restore_fn
    restore_ops.update(saver.restore(file_prefix, options))
  File "C:\Program Files\Python38\lib\site-packages\tensorflow\python\training\saving\functional_saver.py", line 115, in restore
    restore_ops[saveable.name] = saveable.restore(
  File "C:\Program Files\Python38\lib\site-packages\tensorflow\python\training\saving\saveable_object_util.py", line 131, in restore
    return resource_variable_ops.shape_safe_assign_variable_handle(
  File "C:\Program Files\Python38\lib\site-packages\tensorflow\python\ops\resource_variable_ops.py", line 307, in shape_safe_assign_variable_handle
    shape.assert_is_compatible_with(value_tensor.shape)
  File "C:\Program Files\Python38\lib\site-packages\tensorflow\python\framework\tensor_shape.py", line 1134, in assert_is_compatible_with
    raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (340,) and (40,) are incompatible

The command I'm using: python train.py --epochs 200 --batch-size 4 --start-eval-epoch 0 --model-type p5 --use-pretrain True --dataset-type coco --dataset dataset/CV1/ --num-classes 5 --class-names CV1.names --coco-train-set train --coco-valid-set val --augment ssd_random_crop --p5-coco-pretrained-weights checkpoints/best_weight_p5_27_0.872

Let me know if you need the weight files to test; I can share them.
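The two shapes in the error look consistent with the width of a YOLO detection head, which is usually anchors_per_scale * (num_classes + 5) (four box coordinates plus one objectness score per anchor). Assuming 4 anchors per scale, which is the Scaled-YOLOv4 P5 setting, an 80-class COCO head gives 340 channels and the 5-class head trained with --num-classes 5 gives 40. That would mean an 80-class model was being restored from a 5-class checkpoint. This reading is an inference from the numbers, not something stated in the thread, and the helper below is hypothetical:

```python
def head_channels(num_classes: int, anchors_per_scale: int = 4) -> int:
    """Output channels of one YOLO head conv layer (hypothetical helper).

    Assumes `anchors_per_scale` anchors, each predicting x, y, w, h,
    objectness (5 values) plus one score per class.
    """
    return anchors_per_scale * (num_classes + 5)


print(head_channels(80))  # 340: width of an 80-class (COCO) head
print(head_channels(5))   # 40: width of a head trained with --num-classes 5
```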

wangermeng2021 commented 3 years ago

Resuming training from checkpoints is not implemented yet; I will implement it soon. A temporary solution:

elif args.model_type == "p5":
    model = Yolov4(args, training=True)
    if args.use_pretrain:
        if len(os.listdir(os.path.dirname(args.p5_coco_pretrained_weights))) != 0:
            try:
                # Succeeds when the checkpoint matches the current model,
                # e.g. one saved by an earlier run with the same --num-classes.
                model.load_weights(args.p5_coco_pretrained_weights).expect_partial()
                print("Load {} checkpoints successfully!".format(args.model_type))
            except Exception:
                # Otherwise treat it as the 80-class COCO checkpoint: restore it
                # into an 80-class model, then copy every layer except the
                # class-dependent head into the current model.
                cur_num_classes = int(args.num_classes)
                args.num_classes = 80
                pretrain_model = Yolov4(args, training=True)
                pretrain_model.load_weights(args.p5_coco_pretrained_weights).expect_partial()
                for layer in model.layers:
                    if not layer.get_weights():
                        continue  # layer has no weights to copy
                    if 'yolov3_head' in layer.name:
                        continue  # head shape depends on num_classes
                    layer.set_weights(pretrain_model.get_layer(layer.name).get_weights())
                args.num_classes = cur_num_classes
                print("Load {} weight successfully!".format(args.model_type))
        else:
            raise ValueError("pretrained_weights directory is empty!")
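The fallback branch of this workaround copies weights layer by layer and skips any layer whose name contains `yolov3_head`. The sketch below reproduces that idea on a toy tf.keras model so it can run without the repository. `build_model`, `transfer_except_head` and the layer names `backbone_dense` / `yolov3_head_out` are hypothetical stand-ins, not part of the repo; the head width reuses the 4 * (num_classes + 5) assumption.

```python
import numpy as np
import tensorflow as tf


def build_model(num_classes: int) -> tf.keras.Model:
    # Hypothetical toy stand-in for Yolov4: one shared "backbone" layer plus a
    # class-dependent head whose width is 4 * (num_classes + 5).
    inputs = tf.keras.Input(shape=(8,))
    x = tf.keras.layers.Dense(16, name="backbone_dense")(inputs)
    outputs = tf.keras.layers.Dense(4 * (num_classes + 5), name="yolov3_head_out")(x)
    return tf.keras.Model(inputs, outputs)


def transfer_except_head(model, pretrained, head_tag="yolov3_head"):
    """Copy weights by matching layer names, leaving head layers untouched."""
    for layer in model.layers:
        if not layer.get_weights():  # e.g. an input layer holds no weights
            continue
        if head_tag in layer.name:   # shapes differ with num_classes: keep fresh init
            continue
        layer.set_weights(pretrained.get_layer(layer.name).get_weights())


pretrained = build_model(num_classes=80)  # plays the role of the COCO model
model = build_model(num_classes=5)        # model being trained on 5 classes
transfer_except_head(model, pretrained)
```

Because layers are matched by name, this approach assumes both models are built by the same code so that corresponding layers receive the same names; a layer present in `model` but missing from `pretrained` would make `get_layer` raise.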

wangermeng2021 commented 3 years ago

Resuming training from checkpoints is supported now.

aliencaocao commented 3 years ago

Okay thanks a lot!