deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Mask R-CNN transfer learning example #170

Open BernhardGlueck opened 3 years ago

BernhardGlueck commented 3 years ago

Any chance of getting a simple example (even just ad hoc) of how to train Mask R-CNN via transfer learning on a custom COCO-formatted dataset?

We are currently evaluating whether we need to write our main application in Python because of this requirement, or whether we can switch to the JVM/Kotlin (which we would prefer), and I would really like to do a small POC on this.

lanking520 commented 3 years ago

Yes, there should be a way to achieve that. Which DL framework are you looking for? Currently we support MXNet transfer learning and experimental PyTorch transfer learning. We have a Mask R-CNN model in the MXNet model zoo; the PyTorch one is on the way.

BernhardGlueck commented 3 years ago

MXNet is totally fine...

lanking520 commented 3 years ago

So here are the steps.

1) Get the pretrained Mask R-CNN model: http://docs.djl.ai/mxnet/mxnet-model-zoo/index.html; you can choose one of the backbones. Try to get it running with our instance segmentation example: http://docs.djl.ai/mxnet/mxnet-model-zoo/index.html.

2) Prepare your dataset. We do have a COCO dataset in our basicdataset module (https://github.com/awslabs/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/CocoDetection.java); you can implement your custom dataset based on it.

3) Train the model using the dataset. You can just create a Trainer from the model. Since the model comes from MXNet, you can follow steps similar to the original model training: https://gluon-cv.mxnet.io/build/examples_instance/train_mask_rcnn_coco.html. In DJL, you can spawn a Trainer from the pretrained model; see the sketch after this list.
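Here is a rough, untested sketch of those three steps in one place. Treat it as pseudocode-level guidance: the zoo filter values, the input shape, the batch size, and the loss function are assumptions, package locations vary across DJL versions, and a real Mask R-CNN training loop needs its own multi-task loss rather than the placeholder used here.

```java
import java.nio.file.Paths;

import ai.djl.Application;
import ai.djl.basicdataset.cv.CocoDetection; // ai.djl.basicdataset.CocoDetection in older releases
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;

public final class MaskRcnnTransferSketch {

    public static void main(String[] args) throws Exception {
        // Step 1: pull a pretrained Mask R-CNN from the MXNet model zoo and
        // sanity-check it with a single instance-segmentation prediction.
        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .optApplication(Application.CV.INSTANCE_SEGMENTATION)
                        .setTypes(Image.class, DetectedObjects.class)
                        .optFilter("backbone", "resnet18") // pick a backbone the zoo offers
                        .optEngine("MXNet")
                        .optProgress(new ProgressBar())
                        .build();

        try (ZooModel<Image, DetectedObjects> model = criteria.loadModel()) {
            try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
                Image img = ImageFactory.getInstance().fromFile(Paths.get("test.jpg"));
                System.out.println(predictor.predict(img));
            }

            // Step 2: the built-in COCO dataset (downloads COCO if not cached).
            // A custom COCO-formatted dataset can follow the same pattern.
            CocoDetection trainSet =
                    CocoDetection.builder()
                            .optUsage(Dataset.Usage.TRAIN)
                            .setSampling(2, true) // small batch; Mask R-CNN is memory hungry
                            .build();
            trainSet.prepare(new ProgressBar());

            // Step 3: spawn a Trainer from the pretrained model. The loss here is a
            // stand-in; real Mask R-CNN training combines RPN, box, class and mask losses.
            DefaultTrainingConfig config =
                    new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                            .addTrainingListeners(TrainingListener.Defaults.logging());
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.initialize(new Shape(2, 3, 512, 512)); // input shape is an assumption
                EasyTrain.fit(trainer, 1, trainSet, null);
            }
        }
    }
}
```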

The last part is pretty difficult since the Mask R-CNN model itself is complex. You may need to do something similar to what we did for ResNet (https://github.com/awslabs/djl/blob/master/examples/src/main/java/ai/djl/examples/training/transferlearning/TrainResnetWithCifar10.java#L133-L156) if the output classes are not the same as in the original dataset.
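For the layer-surgery part, the pattern from the linked ResNet/CIFAR-10 example looks roughly like the sketch below. `replaceOutputLayer` is a hypothetical helper and the single Linear head is a simplification: Mask R-CNN has separate class, box, and mask heads that would each need similar treatment, which is exactly the hard part.

```java
import ai.djl.Model;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.SymbolBlock;
import ai.djl.nn.core.Linear;

final class HeadReplacement {

    /**
     * Strips the last block of an imported MXNet symbol and appends a new
     * Linear output layer sized for the new class count, mirroring the
     * approach in TrainResnetWithCifar10.
     */
    static void replaceOutputLayer(Model model, int numClasses) {
        SymbolBlock block = (SymbolBlock) model.getBlock();
        block.removeLastBlock(); // drop the original output layer
        SequentialBlock newBlock = new SequentialBlock()
                .add(block)
                .add(Blocks.batchFlattenBlock())
                .add(Linear.builder().setUnits(numClasses).build());
        model.setBlock(newBlock);
    }

    private HeadReplacement() {}
}
```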

My recommendation, to save time, is to use GluonCV (Python) and change the layers to fit your use case. Once this is done, you can convert the model into an MXNet Symbol and train it in Java with the dataset you implemented.
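If you go the GluonCV route, the exported symbol/params pair can then be loaded back into DJL for training, roughly like this. The directory and the `mask_rcnn` prefix are hypothetical, assuming the usual MXNet naming convention of `<prefix>-symbol.json` and `<prefix>-0000.params`; the loss is again only a placeholder.

```java
import java.nio.file.Paths;

import ai.djl.Model;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.loss.Loss;

public final class LoadExportedSymbol {

    public static void main(String[] args) throws Exception {
        // Load the GluonCV export (./exported/mask_rcnn-symbol.json + params) with the MXNet engine.
        try (Model model = Model.newInstance("mask_rcnn", "MXNet")) {
            model.load(Paths.get("exported"), "mask_rcnn");

            DefaultTrainingConfig config =
                    new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss());
            try (Trainer trainer = model.newTrainer(config)) {
                // initialize(...) and EasyTrain.fit(...) as in the earlier sketch,
                // using the custom COCO-formatted dataset.
            }
        }
    }
}
```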

BernhardGlueck commented 3 years ago

That sounds like a plan :-) However, I think I will have to go with modifying the network architecture directly in Java, since we plan to train on quite a few datasets with different classes; the GluonCV -> MXNet -> DJL route would have to be automated, and that in itself would be quite a lot of work.