experiencor / keras-yolo3

Training and Detecting Objects with YOLO3
MIT License
1.6k stars, 861 forks

How to do transfer learning? #241

Open ZenoLM opened 4 years ago

ZenoLM commented 4 years ago

After training the network on a custom dataset and saving the model, I want to train the same model on another dataset. I've tried editing the "pre-trained weights" parameter in the configuration file, replacing yolov3.weights with the path to the trained model, but training does not seem to load it when it starts. Am I missing something? How can I train the model on different datasets one after another?

animikhaich commented 4 years ago

There are two implementations in this repo, and each uses its own kind of weights file.

  1. A Keras implementation of YOLOv3 with which you can train and run inference on a custom dataset. It saves the model as Keras weights (.h5) and reads its parameters from config.json. The associated files are train.py and predict.py.
  2. A second implementation that also uses Keras to build the architecture and run prediction, but loads its weights from the native yolov3.weights file, i.e. the original Darknet weights in Darknet format. The associated file is yolo3_one_file_to_detect_them_all.py. It can only run a demo inference on a given image; the code in this repo cannot use it for training.
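Since the two code paths expect different file formats, it can help to check which kind of file you are pointing the config at. The sketch below is a hypothetical helper, not part of the repo: it only tests for the standard 8-byte HDF5 signature that Keras .h5 files start with, assuming the signature sits at byte offset 0 (the usual case for files written by Keras/h5py). It confirms "this is an HDF5 file", not "this is a valid YOLOv3 model".

```python
# Every HDF5 file (Keras .h5 weights included) begins with this 8-byte signature.
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"


def is_hdf5_file(path):
    """Return True if the file at `path` starts with the HDF5 signature.

    Assumes the signature is at offset 0, as in .h5 files saved by Keras.
    A Darknet yolov3.weights file is a raw binary dump, so it returns False.
    """
    with open(path, "rb") as handle:
        # A file shorter than 8 bytes simply fails the comparison.
        return handle.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE
```

You would expect True for backend.h5 or a model saved by train.py, and False for the Darknet yolov3.weights, which only yolo3_one_file_to_detect_them_all.py can read.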

With that background, back to your question: the author of this repo already provides a Keras version of the default yolov3.weights, "backend.h5", which can be downloaded HERE.

If config.json does not point to an existing pre-trained Keras weights file (.h5), the program automatically loads the COCO-trained default "backend.h5" weights at the start of training. Hence, even though the initial layers are not frozen, the weights are initialized from the default yolov3.weights (in their Keras form). Conversely, pointing config.json at the .h5 file saved by your first training run should make the second run start from that model.
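The selection logic described above can be sketched as a small function. This is a minimal illustration, not the repo's code: it ASSUMES the config key that names the Keras weights is `config["train"]["saved_weights_name"]` (check your own config.json), and `pick_initial_weights` is a hypothetical name. It returns the file to load plus whether layers should be matched by name, which is what lets a backbone-only file initialize part of a larger model.

```python
import os

# Default COCO-trained Keras weights mentioned in this thread.
DEFAULT_BACKEND = "backend.h5"


def pick_initial_weights(config, exists=os.path.exists):
    """Decide which .h5 file to load before training starts.

    Returns (path, by_name). by_name=True means "match layers by name".
    ASSUMPTION: the weights path lives at config["train"]["saved_weights_name"].
    """
    path = config.get("train", {}).get("saved_weights_name", "")
    if path and exists(path):
        # A model saved by an earlier run: resume from all of its weights.
        return path, False
    # Nothing usable configured: fall back to the default backend weights.
    return DEFAULT_BACKEND, True


if __name__ == "__main__":
    # Hypothetical file name for the model trained on the first dataset.
    cfg = {"train": {"saved_weights_name": "first_dataset.h5"}}
    print(pick_initial_weights(cfg, exists=lambda p: p == "first_dataset.h5"))
    # prints ('first_dataset.h5', False)
```

In Keras 2 / tf.keras the returned pair maps onto `model.load_weights(path, by_name=by_name)`. One caveat for switching datasets: each YOLOv3 output layer has 3 × (5 + number of classes) filters, so if the second dataset has a different label set, loading the first model's weights in full will likely fail on those layers with a shape mismatch.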

If you want strict transfer learning, you can modify the code to freeze the first few convolutional layers. The architecture is defined in yolo.py.
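Freezing the first layers could look roughly like the sketch below. It is hypothetical and not taken from yolo.py: it works on any object with an ordered `layers` list whose items have `name` and `trainable` attributes (a Keras model qualifies), and the cut-off `n_frozen` is an assumption you would choose yourself, for example by reading `model.summary()`.

```python
from types import SimpleNamespace


def freeze_first_layers(model, n_frozen):
    """Freeze the first n_frozen layers of `model` and unfreeze the rest.

    `model` is any object with an ordered `layers` list (e.g. a Keras model).
    n_frozen is a hypothetical cut-off. Returns the frozen layer names.
    """
    if n_frozen < 0:
        raise ValueError("n_frozen must be non-negative")
    for index, layer in enumerate(model.layers):
        # Layers before the cut-off keep their pre-trained weights fixed;
        # layers after it stay trainable.
        layer.trainable = index >= n_frozen
    return [layer.name for layer in model.layers[:n_frozen]]


if __name__ == "__main__":
    # Stand-in for a real model: five layers with made-up names.
    demo = SimpleNamespace(
        layers=[SimpleNamespace(name=f"conv_{i}", trainable=True) for i in range(5)]
    )
    print(freeze_first_layers(demo, 2))  # prints ['conv_0', 'conv_1']
```

In Keras, a change to `trainable` only takes effect when the model is compiled, so the call belongs before `compile()`; in this repo that is presumably where train.py builds and compiles the training model.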