jjkislele / SketchTriplet

A PyTorch Implementation for Sketch Triplet Networks

Triplet_Loss_SBIR in Pytorch

This repo contains a PyTorch implementation of T. Bui et al.'s paper "Compact Descriptors for Sketch-based Image Retrieval using a Triplet loss Convolutional Neural Network" [Repo|Page].

One detail of the proposed architecture confused me: in the paper's Figure 1, the conv4 layer of each branch has no ReLU immediately after it, while in the original code conv4 does. I followed the original code when building the network.
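To illustrate the point, a single branch built the way the original code does it (with a ReLU after conv4) might look like the sketch below. The channel counts and kernel sizes are placeholders, not the paper's exact configuration:

```python
import torch
import torch.nn as nn

class Branch(nn.Module):
    """Illustrative single branch; layer sizes are placeholders,
    not the exact configuration from the paper or original code."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=15, stride=3), nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
            nn.Conv2d(64, 128, kernel_size=5), nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU(),
            # conv4: the paper's Figure 1 omits a ReLU here, but the
            # original code applies one -- this sketch follows the code.
            nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(),
        )

    def forward(self, x):
        return self.features(x)
```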

The network seems able to reproduce the results, though there is still much room for improvement in my code:

Dependency

How to Run

First, the pretrained model trained on Flickr15k can be downloaded here, and the Flickr15k dataset can be downloaded here. A resized Flickr15k used for preview is provided here. The 330sketches query set can be downloaded here, and the groundtruth is provided here. Canny edge detection must be run to produce the photographs' edge maps; Flickr_15K_edge2, containing precomputed edge maps, is also provided here.
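The edge-map step would typically use OpenCV's `cv2.Canny`. As a self-contained illustration of the idea only, the toy function below computes a thresholded gradient-magnitude edge map; it is a simplified stand-in, not the real Canny detector:

```python
import numpy as np

def simple_edge_map(gray, threshold=30.0):
    """Toy gradient-magnitude edge map -- a simplified stand-in for
    cv2.Canny, which the actual preprocessing would use."""
    gray = gray.astype(np.float64)
    # Central finite differences in x and y.
    gx = np.zeros_like(gray)
    gy = np.zeros_like(gray)
    gx[:, 1:-1] = gray[:, 2:] - gray[:, :-2]
    gy[1:-1, :] = gray[2:, :] - gray[:-2, :]
    mag = np.hypot(gx, gy)
    # Binarize: edge pixels white (255), background black (0).
    return np.where(mag > threshold, 255, 0).astype(np.uint8)
```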

Second, modify the root path to Flickr15k in ./train.py. The model's checkpoints will be stored at ./out/flickr15k_yymmddHHMM/*.pth. The default root path is ../deep_hashing, matching my setup.

Third, run ./train.py to train the network. Use ./extract_feat_sketch.py and ./extract_feat_photo.py to extract features from sketches and photographs. The features will be stored at ./out_feat/flickr15k_yymmddHHMM/feat_sketch.npz and ./out_feat/flickr15k_yymmddHHMM/feat_photo.npz.

Last, use ./retrieval.py to obtain the results. The retrieval lists will be stored at ./out_feat/flickr15k_yymmddHHMM/result. To stay consistent with the 330sketches query file structure, the results for each query are saved per group and sorted by similarity.
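Retrieval itself reduces to ranking the photo features by similarity to each sketch feature. A minimal cosine-similarity sketch of that ranking step (illustrative, not the exact logic in ./retrieval.py):

```python
import numpy as np

def rank_by_cosine(query, gallery):
    """Return gallery indices sorted from most to least similar.

    query:   (D,) feature vector for one sketch
    gallery: (N, D) feature matrix for the photo edge maps
    """
    q = query / np.linalg.norm(query)
    g = gallery / np.linalg.norm(gallery, axis=1, keepdims=True)
    sims = g @ q              # cosine similarity per photo
    return np.argsort(-sims)  # descending similarity
```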

Code Structure

.
├── accessory
│   └── pr_curve.png
├── dataset
│   ├── 330sketches
│   ├── groundtruth
│   └── Flickr_15K_edge2
├── extract_feat_photo.py
├── extract_feat_sketch.py
├── flickr15k_dataset.py
├── model
│   ├── SketchTriplet_half_sharing.py
│   └── SketchTriplet.py
├── out
│   └── flickr15k_1904041458
│       ├── 500.pth
│       └── loss_and_accurary.txt
├── out_feat
│   └── flickr15k_1904041458
│       ├── feat_photo.npz
│       ├── feat_sketch.npz
│       └── result
├── README.md
├── retrieval.py
├── train.py
└── utils.py

Results - Flickr15k

We train the network SketchTriplet on the Flickr15k dataset. The network takes an anchor (a sketch), a positive example (a photograph edge map of the same class as the anchor), and a negative example (a photograph edge map of a different class than the anchor).
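The training objective over these (anchor, positive, negative) embeddings is the standard triplet margin loss; a minimal numpy sketch of the computation (the margin value here is illustrative, and the actual training is done in PyTorch):

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=1.0):
    """max(0, d(a, p) - d(a, n) + margin), averaged over the batch."""
    d_pos = np.sum((anchor - positive) ** 2, axis=1)  # squared L2 to positive
    d_neg = np.sum((anchor - negative) ** 2, axis=1)  # squared L2 to negative
    return np.maximum(0.0, d_pos - d_neg + margin).mean()
```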

Some parameters are as follows:

After 500 epochs of training, here is the PR curve we get on the testing set.

PR curve for the testing set

Also, the loss curve during training is shown below.

Loss curve during training

Although it scores 67.7% mAP, indicating middling performance, the PR curve shows the model is overfitting.
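For reference, mAP over ranked retrieval lists can be computed as below. This is a generic sketch, not the repo's evaluation code:

```python
import numpy as np

def average_precision(ranked_relevant):
    """AP for one query: ranked_relevant is a boolean sequence over the
    ranked list, True where the retrieved item is a ground-truth match."""
    rel = np.asarray(ranked_relevant, dtype=float)
    if rel.sum() == 0:
        return 0.0
    hits = np.cumsum(rel)
    precision_at_k = hits / (np.arange(len(rel)) + 1)
    # Average precision over the ranks of the relevant items only.
    return float((precision_at_k * rel).sum() / rel.sum())

def mean_average_precision(per_query):
    """mAP: mean of per-query APs."""
    return float(np.mean([average_precision(r) for r in per_query]))
```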

Todo List

Reference and Special Thanks

[1] adambielski's siamese-triplet repo

[2] weixu000's DSH-pytorch repo

[3] TuBui's Triplet_Loss_SBIR repo