(Check out my Medium post (https://goo.gl/Ussdr1) for more details of this project.)
The goal of this project is to develop models for the Dstl Satellite Imagery Feature Detection contest on Kaggle. The result scores 0.46 on the public test set and 0.44 on the private test set, which would rank No. 7 out of 419 on the private leaderboard.
The training dataset includes 25 images, each with 20 channels (RGB band (3 channels) + A band (8 channels) + M band (8 channels) + P band (1 channel)), and the corresponding object labels. There are 10 types of overlapping objects, labeled with contours in WKT format: 0. Buildings, 1. Misc, 2. Road, 3. Track, 4. Trees, 5. Crops, 6. Waterway, 7. Standing water, 8. Vehicle Large, 9. Vehicle Small.
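Each contour has to be rasterized into a per-class mask before a pixel-wise classifier can be trained. The sketch below is a minimal example that makes two assumptions: the WKT has already been parsed into lists of (x, y) vertices, and those vertices have already been scaled to pixel coordinates (neither step is shown). It fills a polygon with holes using even-odd ray casting on pixel centers in plain NumPy. `polygon_to_mask` is a hypothetical name, and the repository may use OpenCV or shapely for this step.

```python
import numpy as np


def _inside_ring(xs, ys, ring):
    """Even-odd test: True where the point (xs, ys) lies inside the closed ring."""
    ring = np.asarray(ring, dtype=float)
    ax, ay = ring[:, 0], ring[:, 1]
    bx, by = np.roll(ax, -1), np.roll(ay, -1)  # each vertex paired with the next one
    inside = np.zeros(xs.shape, dtype=bool)
    for x0, y0, x1, y1 in zip(ax, ay, bx, by):
        # Does this edge straddle the horizontal line through the pixel center?
        crosses = (y0 > ys) != (y1 > ys)
        # x where the edge meets that line (horizontal edges never "cross", so nan/inf is harmless)
        with np.errstate(divide="ignore", invalid="ignore"):
            x_cross = x0 + (ys - y0) * (x1 - x0) / (y1 - y0)
        # Toggle for every edge crossed by a ray cast toward +x
        inside ^= crosses & (xs < x_cross)
    return inside


def polygon_to_mask(exterior, holes=(), shape=(8, 8)):
    """Rasterize one polygon (exterior ring plus optional holes) into a 0/1 mask."""
    height, width = shape
    ys, xs = np.mgrid[0:height, 0:width] + 0.5  # pixel-center coordinates
    mask = _inside_ring(xs, ys, exterior)
    for hole in holes:
        mask &= ~_inside_ring(xs, ys, hole)  # carve out interior rings
    return mask.astype(np.uint8)
```

A multipolygon is handled by OR-ing the masks of its member polygons. Because every class gets its own mask, overlapping objects of different classes do not conflict.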
This code converts the contours into masks and then trains a pixel-wise binary classifier for each class of object. A U-Net with batch normalization, implemented in TensorFlow, is used as the classification model. The loss function is a combination of cross entropy and the soft Jaccard index, and the optimizer is Adam. The following figures show the features and labels of one of the training images. The code to generate these figures can be found in visualization.ipynb.
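The text does not say exactly how cross entropy and the soft Jaccard index are combined. The NumPy sketch below assumes a weighted sum, loss = BCE + w * (1 - J), where J = sum(y*p) / (sum(y) + sum(p) - sum(y*p)). J is computed on probabilities rather than thresholded predictions, so it stays differentiable. Another common choice is BCE - log(J). The function names and `jaccard_weight` are hypothetical, and the actual model computes its loss in TensorFlow.

```python
import numpy as np


def soft_jaccard(y_true, y_prob, eps=1e-7):
    """Soft (differentiable) Jaccard index between a 0/1 mask and predicted probabilities."""
    intersection = np.sum(y_true * y_prob)
    union = np.sum(y_true) + np.sum(y_prob) - intersection
    return (intersection + eps) / (union + eps)


def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean pixel-wise binary cross entropy; probabilities are clipped to avoid log(0)."""
    p = np.clip(y_prob, eps, 1.0 - eps)
    return -np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))


def combined_loss(y_true, y_prob, jaccard_weight=1.0):
    """Assumed combination: BCE plus a weighted (1 - soft Jaccard) penalty."""
    return binary_cross_entropy(y_true, y_prob) + jaccard_weight * (
        1.0 - soft_jaccard(y_true, y_prob)
    )
```

The Jaccard term directly rewards overlap with small objects (such as vehicles) that contribute very little to a pixel-averaged cross entropy.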
This figure shows the percentage of area covered by each class in every training image. (Note: on some images the sum exceeds 100% because classes overlap.)
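The per-image numbers behind such a figure can be computed in a few lines of NumPy. This sketch assumes the masks of one image are stacked into an array of shape (10, height, width). `class_area_percentages` is a hypothetical helper. Each class has its own mask, so a pixel in an overlap counts toward every class it belongs to, and that is why the total can exceed 100%.

```python
import numpy as np


def class_area_percentages(masks):
    """Percentage of image area covered by each class.

    masks: array of shape (n_classes, H, W) holding 0/1 values for one image.
    """
    masks = np.asarray(masks, dtype=bool)
    # Flatten each class mask and take the fraction of covered pixels
    return 100.0 * masks.reshape(masks.shape[0], -1).mean(axis=1)
```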
The model was developed and trained on a p2.xlarge instance on AWS, which comes with the above hardware. At the beginning of training for each class, all 25 training images and their labels are loaded into RAM to avoid file I/O during training, which would otherwise slow it down. A large amount of RAM (up to 50 GB) is therefore required. The batch size and patch size used in training and prediction are also tuned for the ~11 GB of memory on the K80 GPU. Adjust these parameters to match your hardware.
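With all the images held in memory, each training batch can be cut from RAM without touching disk. The sketch below is a minimal random-patch sampler. It assumes each image is a (height, width, 20) array and each label is a (height, width) mask for the current class. `sample_patches` is a hypothetical name. `batch_size` and `patch_size` are the two knobs to shrink if the GPU runs out of memory.

```python
import numpy as np


def sample_patches(images, labels, batch_size, patch_size, rng=None):
    """Draw a random batch of aligned (image, label) patches from in-memory arrays.

    images: list of (H, W, C) arrays; labels: list of matching (H, W) masks.
    """
    rng = np.random.default_rng() if rng is None else rng
    channels = images[0].shape[-1]
    x_batch = np.empty((batch_size, patch_size, patch_size, channels), dtype=images[0].dtype)
    y_batch = np.empty((batch_size, patch_size, patch_size), dtype=labels[0].dtype)
    for i in range(batch_size):
        k = rng.integers(len(images))  # pick a training image
        h, w = labels[k].shape
        top = rng.integers(0, h - patch_size + 1)  # random top-left corner
        left = rng.integers(0, w - patch_size + 1)
        x_batch[i] = images[k][top:top + patch_size, left:left + patch_size]
        y_batch[i] = labels[k][top:top + patch_size, left:left + patch_size]
    return x_batch, y_batch
```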
To install all the requirements:
pip install -r requirements.txt
conda install -c https://conda.binstar.org/menpo opencv3
Download the data from the contest website: https://www.kaggle.com/c/dstl-satellite-imagery-feature-detection/data
Put the data into the ./data/ folder.
The model is built to train a pixel-wise binary classifier for each of the 10 classes. To switch between classes, set the parameter class_type in ./hypes/hypes.json to a number from 0 to 9. Run the following in the terminal to train a model for each class:
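Editing hypes.json by hand ten times is error-prone, so the switch can be scripted. The sketch below assumes class_type is a top-level key in ./hypes/hypes.json. `set_class_type` is a hypothetical helper, not part of the repository. It validates the value and rewrites the file, leaving the other hyperparameters untouched.

```python
import json


def set_class_type(hypes_path, class_type):
    """Rewrite the top-level class_type entry of a hypes.json file in place."""
    if not isinstance(class_type, int) or not 0 <= class_type <= 9:
        raise ValueError("class_type must be an integer from 0 to 9")
    with open(hypes_path) as f:
        hypes = json.load(f)
    hypes["class_type"] = class_type  # all other keys are preserved
    with open(hypes_path, "w") as f:
        json.dump(hypes, f, indent=2)
    return hypes
```

Calling set_class_type("./hypes/hypes.json", c) before each python train.py run (for c from 0 to 9) trains all ten models in sequence.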
python train.py |& tee output.txt
All printed output is saved in output.txt. All other logs for each training run are saved in a folder under ./log_dir named ./log_dir/month-day_hour-min_lossfunction. These logs include a TF checkpoint every 1000 batches, a summary every 100 batches, and the hyperparameters used for the run. The last TF checkpoint is used to generate predictions.
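The folder naming and the save schedule can be expressed with two small helpers. The sketch below is only an illustration: the zero-padded strftime format, the loss name "bce_jaccard", and the assumption that steps count from 1 are guesses that may differ from train.py, and both function names are hypothetical.

```python
import os
from datetime import datetime


def make_log_dir_name(loss_name, now=None, root="./log_dir"):
    """Build a run folder name of the form <root>/month-day_hour-min_lossfunction."""
    now = datetime.now() if now is None else now
    return os.path.join(root, f"{now:%m-%d_%H-%M}_{loss_name}")


def is_due(step, every):
    """True on every `every`-th step (steps assumed to start at 1)."""
    return step > 0 and step % every == 0


# In a training loop: write a summary when is_due(step, 100),
# save a checkpoint when is_due(step, 1000).
```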
The final version of this code uses all the labeled data for training. To perform cross validation, set test_names in ./utils/train_utils.py and exclude those images from the train_names parameter.
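With only 25 labeled images, k-fold splitting is a reasonable way to choose test_names so that every image is held out exactly once. The sketch below is a minimal example. The image IDs in the test are placeholders rather than the real contest IDs (which look like '6100_0_2'), and `k_fold_splits` is a hypothetical helper. The resulting lists would be copied into train_names and test_names by hand or by code.

```python
def k_fold_splits(image_ids, n_folds=5):
    """Yield (train_names, test_names) pairs; each image is held out in exactly one fold."""
    ids = sorted(image_ids)  # deterministic order so folds are reproducible
    for fold in range(n_folds):
        test_names = ids[fold::n_folds]  # every n_folds-th image, offset by the fold index
        held_out = set(test_names)
        train_names = [name for name in ids if name not in held_out]
        yield train_names, test_names
```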
To monitor training on the fly with TensorBoard, run the following in the terminal:
tensorboard --port 6006 --logdir summary_path --host 127.0.0.1
The following figures show example learning curves from training class 0 (Buildings).
To generate predictions, set the save_path parameter of saver.restore() in inference.py to the path of the last checkpoint, and change class_type in ./hypes/hypes.json to the desired class type:
python inference.py |& tee test_output.txt
All printed output will be saved in test_output.txt. The predictions will be saved in the CSV file ./submission/class_{class_type}.csv.
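Full scenes are too large for the GPU in one pass, so prediction is done on patches. The sketch below shows one way to tile and stitch. It zero-pads the image to a multiple of patch_size, cuts non-overlapping patches, runs a model in chunks of batch_size, and reassembles the probability map. `predict_tiled` is hypothetical, and `predict_fn` stands in for the trained network. Overlapping tiles with averaged predictions is a common refinement that reduces seams at patch borders.

```python
import numpy as np


def predict_tiled(image, predict_fn, patch_size, batch_size=8):
    """Predict a full-size probability map by tiling `image` into patches.

    image: (H, W, C) array; predict_fn maps (N, p, p, C) -> (N, p, p).
    """
    h, w, c = image.shape
    pad_h, pad_w = (-h) % patch_size, (-w) % patch_size
    padded = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)))  # zero padding at bottom/right
    H, W = padded.shape[:2]
    rows, cols = H // patch_size, W // patch_size
    # Split into a (rows*cols, p, p, C) stack, row-major over the tile grid
    patches = (padded.reshape(rows, patch_size, cols, patch_size, c)
               .transpose(0, 2, 1, 3, 4)
               .reshape(-1, patch_size, patch_size, c))
    # Run the model in chunks so GPU memory use is bounded by batch_size
    probs = np.concatenate([predict_fn(patches[i:i + batch_size])
                            for i in range(0, len(patches), batch_size)])
    # Undo the tiling and crop the padding away
    stitched = (probs.reshape(rows, cols, patch_size, patch_size)
                .transpose(0, 2, 1, 3)
                .reshape(H, W))
    return stitched[:h, :w]
```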
To merge the prediction files of all classes (e.g. ./submission/class_0.csv for class 0), run the following in the terminal:
python merge_submission.py
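Conceptually, the merge concatenates the ten per-class files into one submission. The sketch below uses only the standard csv module. It assumes each class file has a header row followed by rows whose first two columns are ImageId and ClassType (as in the contest's ImageId,ClassType,MultipolygonWKT format). That assumption, the output name merged_submission.csv, and the function name are hypothetical and may not match merge_submission.py. The field size limit is raised because a single multipolygon WKT string can exceed the csv module's default limit.

```python
import csv
import os


def merge_class_csvs(submission_dir, n_classes=10, out_name="merged_submission.csv"):
    """Concatenate class_{k}.csv files into one CSV sorted by (ImageId, ClassType)."""
    csv.field_size_limit(2**31 - 1)  # large WKT fields would otherwise raise csv.Error
    header, rows = None, []
    for class_type in range(n_classes):
        path = os.path.join(submission_dir, f"class_{class_type}.csv")
        with open(path, newline="") as f:
            reader = csv.reader(f)
            file_header = next(reader)
            header = header or file_header  # keep the first file's header
            rows.extend(reader)
    rows.sort(key=lambda r: (r[0], int(r[1])))  # assumes columns ImageId, ClassType, ...
    out_path = os.path.join(submission_dir, out_name)
    with open(out_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)
    return out_path
```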
A few non-noded intersection errors were encountered during my submission. Each one can be fixed by running python topology_exception.py for that error. The script topology_exception.py creates a small hole around the offending point, whose coordinates can be found in the error message. You can also run the following in a Python console:
repair_topology_exception('submission/valid_submission.csv',
precision=6,
image_id='6100_0_2',
n_class=4,
point= (0.0073326816112855523, -0.0069418340919529765),
side=1e-4)
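The internals of repair_topology_exception are not shown here. The idea of cutting a hole around the reported point can be sketched as follows. The hypothetical helpers build a square centered on point, round it to precision digits, and format it as WKT. They assume side is the full edge length of the square, which may not be how the script interprets it. With shapely installed, the hole could then be subtracted from the offending polygon with polygon.difference(Polygon(ring)).

```python
def square_hole(point, side=1e-4, precision=6):
    """Closed square ring of edge length `side` centered on `point`, rounded to `precision`."""
    x, y = point
    half = side / 2.0  # assumption: `side` is the full edge length
    corners = [(x - half, y - half), (x + half, y - half),
               (x + half, y + half), (x - half, y + half)]
    ring = [(round(cx, precision), round(cy, precision)) for cx, cy in corners]
    return ring + [ring[0]]  # WKT rings repeat the first vertex to close


def ring_to_wkt(ring):
    """Format a closed ring as a WKT POLYGON string."""
    coords = ", ".join(f"{x} {y}" for x, y in ring)
    return f"POLYGON (({coords}))"
```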
The online evaluation returns a score of 0.44, which would rank No. 7 on the private leaderboard. The following figures compare the true labels with the predicted labels.