zzh8829 / yolov3-tf2

YoloV3 Implemented in Tensorflow 2.0
MIT License

Using custom Darknet network with rectangular input #105

Open kylemcdonald opened 4 years ago

kylemcdonald commented 4 years ago

Hello, and thanks for your work.

I would like to use a network I trained using https://github.com/AlexeyAB/darknet/. It accepts 576x320x3 input and predicts 4 classes.

I started by converting the network:

$ git clone https://github.com/zzh8829/yolov3-tf2.git && cd yolov3-tf2
$ conda env create -f conda-cpu.yml
$ conda activate yolov3-tf2-cpu
$ python convert.py --weights 576x320.weights --output 576x320.tf --num_classes 4
...
I1120 19:21:25.908616 140531829884736 convert.py:22] weights loaded
I1120 19:21:26.194279 140531829884736 convert.py:26] sanity check passed
I1120 19:21:26.551275 140531829884736 convert.py:29] weights saved

Then I ran my code (note: I had to hardcode yolo_iou_threshold and yolo_score_threshold for this code to run):

import numpy as np
from yolov3_tf2.models import YoloV3
yolo = YoloV3(classes=4)
yolo.load_weights('576x320.tf')
x = np.random.random((1,320,576,3)).astype(np.float32)
yolo.predict(x)

But I get an error. I also tried with size=None and got the same error. I think it might be because the network is built under the assumption that width and height are equal, but I can't find where. How can I fix this? Note that I don't get the error with the 320x320 network. I can also pass other square sizes to the network, like 160x160 or 640x640, without problems. I've pasted the error below. Thank you!

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-38-2898b7b95dd6> in <module>
      4 yolo.load_weights('576x320.tf')
      5 x = np.random.random((1, 320, 576, 3)).astype(np.float32)
----> 6 yolo.predict(x)

~/anaconda3/envs/tf2-gpu/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
    907         max_queue_size=max_queue_size,
    908         workers=workers,
--> 909         use_multiprocessing=use_multiprocessing)
    910 
    911   def reset_metrics(self):

~/anaconda3/envs/tf2-gpu/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py in predict(self, model, x, batch_size, verbose, steps, callbacks, **kwargs)
    720         verbose=verbose,
    721         steps=steps,
--> 722         callbacks=callbacks)

~/anaconda3/envs/tf2-gpu/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_arrays.py in model_iteration(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq, mode, validation_in_fit, prepared_feed_values_from_dataset, steps_name, **kwargs)
    391 
    392         # Get outputs.
--> 393         batch_outs = f(ins_batch)
    394         if not isinstance(batch_outs, list):
    395           batch_outs = [batch_outs]

~/anaconda3/envs/tf2-gpu/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py in __call__(self, inputs)
   3738         value = math_ops.cast(value, tensor.dtype)
   3739       converted_inputs.append(value)
-> 3740     outputs = self._graph_fn(*converted_inputs)
   3741 
   3742     # EagerTensor.numpy() will often make a copy to ensure memory safety.

~/anaconda3/envs/tf2-gpu/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
   1079       TypeError: For invalid positional/keyword argument combinations.
   1080     """
-> 1081     return self._call_impl(args, kwargs)
   1082 
   1083   def _call_impl(self, args, kwargs, cancellation_manager=None):

~/anaconda3/envs/tf2-gpu/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _call_impl(self, args, kwargs, cancellation_manager)
   1119       raise TypeError("Keyword arguments {} unknown. Expected {}.".format(
   1120           list(kwargs.keys()), list(self._arg_keywords)))
-> 1121     return self._call_flat(args, self.captured_inputs, cancellation_manager)
   1122 
   1123   def _filtered_call(self, args, kwargs):

~/anaconda3/envs/tf2-gpu/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1222     if executing_eagerly:
   1223       flat_outputs = forward_function.call(
-> 1224           ctx, args, cancellation_manager=cancellation_manager)
   1225     else:
   1226       gradient_name = self._delayed_rewrite_functions.register()

~/anaconda3/envs/tf2-gpu/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    509               inputs=args,
    510               attrs=("executor_type", executor_type, "config_proto", config),
--> 511               ctx=ctx)
    512         else:
    513           outputs = execute.execute_with_cancellation(

~/anaconda3/envs/tf2-gpu/lib/python3.7/site-packages/tensorflow_core/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     65     else:
     66       message = e.message
---> 67     six.raise_from(core._status_to_exception(e.code, message), None)
     68   except TypeError as e:
     69     keras_symbolic_tensors = [

~/anaconda3/envs/tf2-gpu/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  Incompatible shapes: [10,10,1,2] vs. [1,10,18,3,2]
     [[node yolo_boxes_0_7/add (defined at /home/kyle/anaconda3/envs/tf2-gpu/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1751) ]]
     [[yolo_nms_7/Reshape_9/_310]]
  (1) Invalid argument:  Incompatible shapes: [10,10,1,2] vs. [1,10,18,3,2]
     [[node yolo_boxes_0_7/add (defined at /home/kyle/anaconda3/envs/tf2-gpu/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1751) ]]
0 successful operations.
0 derived errors ignored. [Op:__inference_keras_scratch_graph_135939]

Function call stack:
keras_scratch_graph -> keras_scratch_graph
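
The shapes in the error are consistent with a square grid being built from a single spatial dimension. Assuming a stride of 32 at the coarsest output, a 320x576 (height x width) input gives a 10x18 feature map, while a meshgrid built only from the first spatial dimension is 10x10. A minimal NumPy sketch of the mismatch (NumPy stands in for TensorFlow here, and `make_grid` is a hypothetical helper, not the repository's code):

```python
import numpy as np

height, width, stride = 320, 576, 32          # stride 32 is an assumption
gh, gw = height // stride, width // stride    # 10 rows, 18 columns

# Stand-in for the box_xy slice of the prediction: (batch, gh, gw, anchors, 2)
box_xy = np.zeros((1, gh, gw, 3, 2), dtype=np.float32)

def make_grid(nx, ny):
    # Mirrors meshgrid + stack + expand_dims; the result has shape (ny, nx, 1, 2)
    gx, gy = np.meshgrid(np.arange(nx), np.arange(ny))
    return np.expand_dims(np.stack([gx, gy], axis=-1), axis=2)

square_grid = make_grid(gh, gh)   # (10, 10, 1, 2): built from one dimension only
rect_grid = make_grid(gw, gh)     # (10, 18, 1, 2): one size per axis

try:
    box_xy + square_grid
    broadcast_ok = True
except ValueError:
    # (10, 10, 1, 2) cannot broadcast against (1, 10, 18, 3, 2)
    broadcast_ok = False

shifted = box_xy + rect_grid      # broadcasts to (1, 10, 18, 3, 2)
```

Square inputs avoid the problem only because both spatial dimensions happen to match.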

Update: it looks like the code that needs to change is inside yolo_boxes(). I had to account for grid_size being different along each axis. When porting models from Darknet, this also means scaling the anchors by the aspect ratio. Here's what I did to get it working:

--- a/yolov3_tf2/models.py
+++ b/yolov3_tf2/models.py
@@ -148,7 +150,7 @@ def YoloOutput(filters, anchors, classes, name=None):

 def yolo_boxes(pred, anchors, classes):
     # pred: (batch_size, grid, grid, anchors, (x, y, w, h, obj, ...classes))
-    grid_size = tf.shape(pred)[1]
+    grid_size = tf.cast(tf.shape(pred)[1:3][::-1], tf.float32)
     box_xy, box_wh, objectness, class_probs = tf.split(
         pred, (2, 2, 1, classes), axis=-1)

@@ -158,12 +160,11 @@ def yolo_boxes(pred, anchors, classes):
     pred_box = tf.concat((box_xy, box_wh), axis=-1)  # original xywh for loss

     # !!! grid[x][y] == (y, x)
-    grid = tf.meshgrid(tf.range(grid_size), tf.range(grid_size))
+    grid = tf.meshgrid(tf.range(grid_size[0]), tf.range(grid_size[1]))
     grid = tf.expand_dims(tf.stack(grid, axis=-1), axis=2)  # [gx, gy, 1, 2]

-    box_xy = (box_xy + tf.cast(grid, tf.float32)) / \
-        tf.cast(grid_size, tf.float32)
-    box_wh = tf.exp(box_wh) * anchors
+    box_xy = (box_xy + tf.cast(grid, tf.float32)) / grid_size
+    box_wh = tf.exp(box_wh) * anchors * (grid_size[0] / grid_size)

     box_x1y1 = box_xy - box_wh / 2
     box_x2y2 = box_xy + box_wh / 2

Note that I copied my anchors from the Darknet yolov3.cfg file and used them like this:

anchors = np.array([5,5,59,3,9,45,59,9,61,19,21,55,44,31,40,54,63,37]).reshape(-1,2)/576
yolo = YoloV3(classes=4, anchors=anchors)
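
To see what the `grid_size[0] / grid_size` factor in the patch does to these anchors, here is the arithmetic in NumPy for the coarsest output scale, assuming a stride of 32 (an 18x10 grid for a 576x320 input). Both anchor components are first divided by the width (576); the factor then rescales the height component so that it is effectively divided by 320. The variable names are illustrative only:

```python
import numpy as np

width, height = 576, 320
anchors_px = np.array(
    [5, 5, 59, 3, 9, 45, 59, 9, 61, 19, 21, 55, 44, 31, 40, 54, 63, 37],
    dtype=np.float32,
).reshape(-1, 2)

# Anchors as passed to YoloV3 above: both components divided by the width
anchors = anchors_px / width

# (grid_w, grid_h) at the coarsest scale, assuming stride 32
grid_size = np.array([width // 32, height // 32], dtype=np.float32)
scale = grid_size[0] / grid_size          # (1.0, 1.8)

scaled = anchors * scale

# Same anchors normalized per axis: width by 576, height by 320
per_axis = anchors_px / np.array([width, height], dtype=np.float32)
```

Under these assumptions `scaled` and `per_axis` agree, because 576 / 320 equals 18 / 10.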

nicolefinnie commented 4 years ago

@kylemcdonald Thanks for sharing. We also have to adapt the yolo_loss function to rectangular inputs, as follows:

def yolo_loss(y_true, y_pred):
    # 3. inverting the pred box equations
    grid_size = tf.cast(tf.shape(y_true)[1:3][::-1], tf.float32)
    grid = tf.meshgrid(tf.range(grid_size[0]), tf.range(grid_size[1]))
    # (Optional) Not needed if the anchors are already normalized per axis
    # (width by image width, height by image height).
    true_wh = tf.math.log(true_wh / anchors * grid_size / grid_size[0])

However, if the anchor priors have already been normalized per axis (width by image width, height by image height), the anchor ratio does not need to be adjusted. For example, if the anchor (5, 5) is normalized to (5/576, 5/320), the extra scaling in the following line of yolo_boxes() is not necessary.

def yolo_boxes(pred, anchors, classes):
    # (Optional) if you have normalized your anchors by `height/width`, this step is not necessary
    box_wh = tf.exp(box_wh) * anchors * (grid_size[0] / grid_size)
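
A quick NumPy round-trip check can confirm that the width/height inversion in the loss matches the decode in yolo_boxes(). This sketch assumes the decode scales anchors by `grid_size[0] / grid_size`, in which case the inverse needs the reciprocal factor `grid_size / grid_size[0]`. The anchor values and random inputs are illustrative only:

```python
import numpy as np

# (grid_w, grid_h) for a 576x320 input, assuming stride 32
grid_size = np.array([18.0, 10.0])

# Three anchors normalized by the width only, as in the earlier comment
anchors = np.array([[44, 31], [40, 54], [63, 37]]) / 576.0

rng = np.random.default_rng(0)
raw_wh = rng.normal(size=(3, 2))   # stand-in for the network's raw w/h outputs

# Decode, following the patched yolo_boxes()
box_wh = np.exp(raw_wh) * anchors * (grid_size[0] / grid_size)

# Invert for the loss using the reciprocal factor
recovered = np.log(box_wh / anchors * grid_size / grid_size[0])
```

If `recovered` matches `raw_wh`, the loss targets and the decoded boxes are using the same convention.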

cosminacho commented 4 years ago

Hi, when you train on rectangular images, don't you also need to modify the code here (dataset.py):

def transform_targets(y_train, anchors, anchor_masks, width, height):
    y_outs = []

    grid_size_w = width // 32
    grid_size_h = height // 32

    --------------------------

    for anchor_idxs in anchor_masks:
        y_outs.append(transform_targets_for_output(
            y_train, grid_size_w, grid_size_h, anchor_idxs))
        grid_size_w *= 2
        grid_size_h *= 2

    return tuple(y_outs)

and here:

@tf.function  # TODO: check here if it is ok!!!
def transform_targets_for_output(y_true, grid_size_w, grid_size_h, anchor_idxs):
    # y_true: (N, boxes, (x1, y1, x2, y2, class, best_anchor))
    N = tf.shape(y_true)[0]

    # y_true_out: (N, grid, grid, anchors, [x, y, w, h, obj, class])
    y_true_out = tf.zeros(
        (N, grid_size_h, grid_size_w, tf.shape(anchor_idxs)[0], 6))

   ---------------------------------------------------------

                anchor_idx = tf.cast(tf.where(anchor_eq), tf.int32)
                grid_xy = tf.cast(box_xy // ((1/grid_size_w),
                                             (1/grid_size_h)), tf.int32)

And finally, you will have to adjust your model to take a custom input size of (width x height) in models.py:

def YoloV3(width=None, height=None, channels=3, anchors=yolo_anchors,
           masks=yolo_anchor_masks, classes=80, training=False):
    x = inputs = Input([height, width, channels], name='input')
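
As a sanity check on the per-axis grid sizes used above, here is a small plain-Python sketch. It assumes three output scales with a coarsest stride of 32 and the grid doubling at each finer scale, as in the snippet for transform_targets(). `grid_sizes` and `cell_index` are hypothetical helpers, not functions from the repository:

```python
width, height = 576, 320

def grid_sizes(width, height, num_scales=3, coarsest_stride=32):
    """Return (grid_w, grid_h) for each output scale, coarsest first."""
    gw, gh = width // coarsest_stride, height // coarsest_stride
    sizes = []
    for _ in range(num_scales):
        sizes.append((gw, gh))
        gw, gh = gw * 2, gh * 2
    return sizes

def cell_index(box_xy, grid_w, grid_h):
    """Map a normalized box center (x, y) in [0, 1) to its integer grid cell."""
    x, y = box_xy
    return int(x * grid_w), int(y * grid_h)

sizes = grid_sizes(width, height)            # [(18, 10), (36, 20), (72, 40)]
cell = cell_index((0.5, 0.25), *sizes[0])    # (9, 2)
```

The sketch multiplies by the grid size instead of floor-dividing by its reciprocal. The two are mathematically equivalent, but multiplication avoids floating-point edge cases at cell boundaries.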