hhk7734 / tensorflow-yolov4

YOLOv4 Implemented in Tensorflow 2.
MIT License
136 stars 75 forks source link

InvalidArgumentError During Training #58

Closed LIHANG-HONG closed 3 years ago

LIHANG-HONG commented 3 years ago

Hi, thank you for your great library. I ran into an InvalidArgumentError while using YOLOv4 to train a custom model. Here is my code.

from yolov4.tf import SaveWeightsCallback, YOLOv4
from PIL import Image
import numpy as np
from tensorflow.keras import callbacks, optimizers

# Create the YOLOv4 wrapper with tiny=True and point it at the class-name file.
yolo = YOLOv4(tiny = True)
yolo.classes='classes.names'
# Load the training dataset (with label smoothing) from the VOC2007 image directory.
train_data_set = yolo.load_dataset(
    '2007_train.txt',
    image_path_prefix='VOCdevkit/VOC2007/JPEGImages',
    label_smoothing=0.02
)
# Load the second dataset with training=False; it is named as the validation set.
val_data_set = yolo.load_dataset(
    "2007_test.txt",
    image_path_prefix='VOCdevkit/VOC2007/JPEGImages',
    training=False
)
# Input size and batch size are assigned here, after both datasets were loaded above.
yolo.input_size=416
yolo.batch_size = 1
# Build the model, then load the weights file using weights_type="yolo".
yolo.make_model()
yolo.load_weights(
    "yolov4-tiny.conv.29",
    weights_type="yolo"
)
# Total number of epochs and base learning rate; both are read by lr_scheduler below.
epochs = 400
lr = 1e-4

#optimizer = optimizers.Adam(learning_rate=lr)
# NOTE(review): AdaBeliefOptimizer is not imported anywhere in this snippet —
# presumably it was imported in another notebook cell; confirm.
optimizer = AdaBeliefOptimizer(epsilon=1e-12, rectify=False, print_change_log = False)
# Compile with the chosen optimizer and loss_iou_type="ciou".
yolo.compile(optimizer=optimizer, loss_iou_type="ciou")

def lr_scheduler(epoch):
    """Return the learning rate to use for the given epoch index.

    The schedule is piecewise constant and is derived from the module-level
    ``epochs`` (total epoch count) and ``lr`` (base learning rate):

    * before 50% of ``epochs``: ``lr``
    * before 80% of ``epochs``: ``lr * 0.5``
    * before 90% of ``epochs``: ``lr * 0.1``
    * otherwise:                ``lr * 0.01``
    """
    # Ordered (fraction of total epochs, multiplier) stages. A multiplier of
    # None means the base rate is returned untouched.
    stages = ((0.5, None), (0.8, 0.5), (0.9, 0.1))
    for fraction, multiplier in stages:
        boundary = int(epochs * fraction)
        if epoch < boundary:
            if multiplier is None:
                return lr
            return lr * multiplier
    # Past the last boundary: final, smallest rate.
    return lr * 0.01

# Callbacks passed to training: the lr_scheduler-driven learning-rate schedule,
# termination on NaN, TensorBoard logging to "yolo_logs", and SaveWeightsCallback
# configured with epoch_per_save=10 and the same "yolo_logs" directory.
_callbacks = [
    callbacks.LearningRateScheduler(lr_scheduler),
    callbacks.TerminateOnNaN(),
    callbacks.TensorBoard(
        log_dir="yolo_logs",
    ),
    SaveWeightsCallback(
        yolo=yolo, dir_path="yolo_logs",
        weights_type="yolo", epoch_per_save=10
    ),
]
# Start training on the training dataset with 50 steps per epoch; the
# validation-related arguments below are commented out.
yolo.fit(
    train_data_set,
    epochs=epochs,
    callbacks=_callbacks,
#    validation_data=val_data_set,
#    validation_steps=50,
#    validation_freq=5,
    steps_per_epoch=50,
)

Here is the error I got when executing yolo.fit() in Jupyter:

Epoch 1/400
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-21-73556c6a70df> in <module>()
      6 #    validation_steps=50,
      7 #    validation_freq=5,
----> 8     steps_per_epoch=50,
      9 )

7 frames
/usr/local/lib/python3.6/dist-packages/yolov4/tf/__init__.py in fit(self, data_set, epochs, verbose, callbacks, validation_data, initial_epoch, steps_per_epoch, validation_steps, validation_freq, **kwargs)
    280             validation_steps=validation_steps,
    281             validation_freq=validation_freq,
--> 282             **kwargs
    283         )
    284 

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
   1098                 _r=1):
   1099               callbacks.on_train_batch_begin(step)
-> 1100               tmp_logs = self.train_function(iterator)
   1101               if data_handler.should_sync:
   1102                 context.async_wait()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
    886         # Lifting succeeded, so variables are initialized and we can run the
    887         # stateless function.
--> 888         return self._stateless_fn(*args, **kwds)
    889     else:
    890       _, _, _, filtered_flat_args = \

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in __call__(self, *args, **kwargs)
   2941        filtered_flat_args) = self._maybe_define_function(args, kwargs)
   2942     return graph_function._call_flat(
-> 2943         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
   2944 
   2945   @property

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1917       # No tape is watching; skip to running the function.
   1918       return self._build_call_outputs(self._inference_function.call(
-> 1919           ctx, args, cancellation_manager=cancellation_manager))
   1920     forward_backward = self._select_forward_and_backward_functions(
   1921         args,

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py in call(self, ctx, args, cancellation_manager)
    558               inputs=args,
    559               attrs=attrs,
--> 560               ctx=ctx)
    561         else:
    562           outputs = execute.execute_with_cancellation(

/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

InvalidArgumentError:  Input to reshape is a tensor with 970368 values, but the requested shape requires a multiple of 14196
     [[node YOLOv4Loss/Reshape (defined at /usr/local/lib/python3.6/dist-packages/yolov4/tf/train.py:64) ]] [Op:__inference_train_function_74742]

Errors may have originated from an input operation.
Input Source operations connected to node YOLOv4Loss/Reshape:
 IteratorGetNext (defined at /usr/local/lib/python3.6/dist-packages/yolov4/tf/__init__.py:282)

Function call stack:
train_function

The versions are as follows: yolov4 2.1.0, tensorflow 2.4.1, Keras 2.4.3.

Could you please help me figure out the reason for this error?

LIHANG-HONG commented 3 years ago

Hi, I have solved the problem by changing the order of the code, from

# Original order: the dataset is loaded first, then input_size and batch_size are set.
train_data_set = yolo.load_dataset(
    '2007_train.txt',
    image_path_prefix='VOCdevkit/VOC2007/JPEGImages',
    label_smoothing=0.02
)
yolo.input_size=416
yolo.batch_size = 1

to

# Changed order: input_size and batch_size are set first, then the dataset is loaded.
yolo.input_size=416
yolo.batch_size = 1
train_data_set = yolo.load_dataset(
    '2007_train.txt',
    image_path_prefix='VOCdevkit/VOC2007/JPEGImages',
    label_smoothing=0.02
)

But I still don't know why this works...

hhk7734 commented 3 years ago

https://github.com/hhk7734/tensorflow-yolov4/blob/4f75970d02b60960257f9f7baa091e0e6874a890/py_src/yolov4/tf/__init__.py#L218-L238

yolo.load_dataset internally uses yolo.input_size and yolo.batch_size, so both must be set before load_dataset is called.