hyichao opened this issue 7 years ago
@hyichao Hello, I want to use ZF instead of VGG too, but the training doesn't converge and the loss fluctuates greatly. Can you share your model modifications, or give me some advice? Thanks!
Hi @ChristineRYY, yes, I did find something useful.
from caffe import layers as L, params as P

def ZFNetBody(net, from_layer, need_fc=True, fully_conv=False, reduced=False,
        dilated=False, nopool=False, dropout=True, freeze_layers=[]):
    kwargs = {
        'param': [dict(lr_mult=1, decay_mult=1), dict(lr_mult=2, decay_mult=0)],
        'weight_filler': dict(type='xavier'),
        'bias_filler': dict(type='constant', value=0)}

    assert from_layer in net.keys()
    # ZF-style trunk: conv1-conv5, ending at conv5 with an overall stride of 16.
    net.conv1 = L.Convolution(net[from_layer], num_output=96, pad=3, kernel_size=7, stride=2, **kwargs)
    net.relu1 = L.ReLU(net.conv1, in_place=True)
    net.pool1 = L.Pooling(net.relu1, pool=P.Pooling.MAX, kernel_size=3, stride=2, pad=1)
    # net.norm1 = L.LRN(net.pool1, lrn_param=dict(local_size=3, alpha=0.00005, beta=0.75, norm_region=1, engine=1))
    net.conv2 = L.Convolution(net.pool1, num_output=256, pad=2, kernel_size=5, stride=2, **kwargs)
    net.relu2 = L.ReLU(net.conv2, in_place=True)
    net.pool2 = L.Pooling(net.relu2, pool=P.Pooling.MAX, kernel_size=3, stride=2, pad=1)
    # net.norm2 = L.LRN(net.pool2, lrn_param=dict(local_size=3, alpha=0.00005, beta=0.75, norm_region=1, engine=1))
    net.conv3 = L.Convolution(net.pool2, num_output=384, pad=1, kernel_size=3, stride=1, **kwargs)
    net.relu3 = L.ReLU(net.conv3, in_place=True)
    net.conv4 = L.Convolution(net.relu3, num_output=384, pad=1, kernel_size=3, stride=1, **kwargs)
    net.relu4 = L.ReLU(net.conv4, in_place=True)
    net.conv5 = L.Convolution(net.relu4, num_output=256, pad=1, kernel_size=3, stride=1, **kwargs)
    net.relu5 = L.ReLU(net.conv5, in_place=True)

    # Update freeze layers: zero out the learning rates of every layer listed in freeze_layers.
    kwargs['param'] = [dict(lr_mult=0, decay_mult=0), dict(lr_mult=0, decay_mult=0)]
    layers = net.keys()
    for freeze_layer in freeze_layers:
        if freeze_layer in layers:
            net.update(freeze_layer, kwargs)

    return net
I removed the LRN layers since they seemed to be useless.
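For reference, wiring this into ssd_pascal.py is mostly a matter of swapping the VGGNetBody call. A rough sketch is below; CreateAnnotatedDataLayer, AddExtraLayers, and the variables come from that script, and the argument values (which layers to freeze, etc.) are only illustrative, so adapt them to your own setup:

# Rough sketch: swap VGGNetBody for ZFNetBody in ssd_pascal.py.
# CreateAnnotatedDataLayer / AddExtraLayers and the variables come from that script;
# the argument values here are only illustrative.
import caffe
from caffe import layers as L, params as P

net = caffe.NetSpec()
net.data, net.label = CreateAnnotatedDataLayer(train_data, batch_size=batch_size_per_device,
        train=True, output_label=True, label_map_file=label_map_file,
        transform_param=train_transform_param, batch_sampler=batch_sampler)

# ZFNetBody replaces the original VGGNetBody call.
ZFNetBody(net, from_layer='data', need_fc=False, freeze_layers=['conv1', 'conv2'])

# Everything after the base network (extra layers, multibox head, loss) stays the same.
AddExtraLayers(net, use_batchnorm)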
As for ssd_pascal.py: in your case ("the training doesn't converge, and the loss fluctuates very greatly"), I suggest adding a scale factor to the data transform parameters:

train_transform_param = {
    'mirror': True,
    'scale': 0.0078125,
    'mean_value': [104, 117, 123],
    'resize_param': {
        'prob': 1,
        'resize_mode': P.Resize.WARP,
        'height': resize_height,
        'width': resize_width,
        'interp_mode': [
            P.Resize.LINEAR,
            P.Resize.AREA,
            P.Resize.NEAREST,
            P.Resize.CUBIC,
            P.Resize.LANCZOS4,
            ],
        },
    'emit_constraint': {
        'emit_type': caffe_pb2.EmitConstraint.CENTER,
        }
    }
test_transform_param = {
    'scale': 0.0078125,
    'mean_value': [104, 117, 123],
    'resize_param': {
        'prob': 1,
        'resize_mode': P.Resize.WARP,
        'height': resize_height,
        'width': resize_width,
        'interp_mode': [P.Resize.LINEAR],
        },
    }
Here 0.0078125 is 1/128.0, which maps the mean-subtracted input from roughly (-128, 127) to (-1, +1).
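Just to sanity-check that mapping with plain Python (standalone, nothing Caffe-specific here):

# Quick check of the preprocessing arithmetic: (pixel - mean) * scale
mean = 117.0         # roughly the middle of the per-channel means 104 / 117 / 123
scale = 0.0078125    # 1 / 128
for pixel in (0.0, 117.0, 255.0):
    print(pixel, '->', (pixel - mean) * scale)
# 0.0   -> about -0.91
# 117.0 -> 0.0
# 255.0 -> about +1.08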
As for ssd_detect.cpp: you MUST also add a line of code to scale the input image before it goes into the network. I do it in the Preprocess function like the following:

cv::Mat sample_normalized;
cv::subtract(sample_float, mean_, sample_normalized);
sample_normalized *= 0.0078125;  // during training the input image is scaled by 1/128
This way the input is preprocessed the same way as in the training stage, and you will see a relatively good result at test time.
Hope this helps.
@hyichao Thanks for the information. I assume ZFNet is pretrained with inputs in (-128, 128)? If you add a 1/128 scale during finetuning, then all the pretrained parameters essentially become 'invalid' (because the input range each layer sees is different). I am curious what mAP you can get with this modification?
@weiliu89 I agree that using scale changes the input range, but I don't think it has a bad influence, as the pretrained layers will still produce a 'correct' response, just with smaller values.
The prior box generation is almost the same. Since the network is changed, I have to change mbox_source_layers into:
# conv5 ==> 20 x 20
# conv6_2 ==> 10 x 10
# conv7_2 ==> 5 x 5
# conv8_2 ==> 3 x 3
# pool6 ==> 1 x 1
mbox_source_layers = ['conv5', 'conv6_2', 'conv7_2', 'conv8_2', 'pool6']
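Keep in mind that when mbox_source_layers changes, the other per-source lists in ssd_pascal.py have to be edited to the same length; the values below are only an illustrative guess, not tuned numbers:

# Per-source lists must all match len(mbox_source_layers); values are illustrative only.
mbox_source_layers = ['conv5', 'conv6_2', 'conv7_2', 'conv8_2', 'pool6']
aspect_ratios = [[2], [2, 3], [2, 3], [2, 3], [2, 3]]
normalizations = [20, -1, -1, -1, -1]  # -1 means no Normalize layer for that source
# min_sizes / max_sizes are still derived from min_ratio / max_ratio and the 300x300
# input as in the VGG script; only their count has to change here.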
Unfortunately the mAP is not good enough on PASCAL --- around 0.50, which is why I opened this issue :( But when I apply this method to simpler cases, I do use a small network (ZFNet or even smaller) to make sure it is fast and accurate, and it works fine.
@hyichao I used your approach to generate the train and test prototxt, trained on the VOC0712 dataset, and finetuned from ZF.caffemodel. When lm_lr > 0.0001 the loss always becomes 1.#QNAN; with lm_lr = 0.0001 it doesn't converge even after 10000 iterations, and detection_eval is about 0.00*. Below is my prototxt; is anything wrong with it?
name: "ZF_VOC0712_SSD_300x300_train" layer { name: "data" type: "AnnotatedData" top: "data" top: "label" include { phase: TRAIN } transform_param { mirror: true mean_value: 104 mean_value: 117 mean_value: 123 force_color: true resize_param { prob: 1 resize_mode: WARP height: 300 width: 300 interp_mode: LINEAR interp_mode: AREA interp_mode: NEAREST interp_mode: CUBIC interp_mode: LANCZOS4 } emit_constraint { emit_type: CENTER } } data_param { source: "E:/imglib/voc2007/VOC0712_trainval_lmdb" batch_size: 32 backend: LMDB } annotated_data_param { batch_sampler { max_sample: 1 max_trials: 1 } batch_sampler { sampler { min_scale: 0.3 max_scale: 1.0 min_aspect_ratio: 0.5 max_aspect_ratio: 2.0 } sample_constraint { min_jaccard_overlap: 0.1 } max_sample: 1 max_trials: 50 } batch_sampler { sampler { min_scale: 0.3 max_scale: 1.0 min_aspect_ratio: 0.5 max_aspect_ratio: 2.0 } sample_constraint { min_jaccard_overlap: 0.3 } max_sample: 1 max_trials: 50 } batch_sampler { sampler { min_scale: 0.3 max_scale: 1.0 min_aspect_ratio: 0.5 max_aspect_ratio: 2.0 } sample_constraint { min_jaccard_overlap: 0.5 } max_sample: 1 max_trials: 50 } batch_sampler { sampler { min_scale: 0.3 max_scale: 1.0 min_aspect_ratio: 0.5 max_aspect_ratio: 2.0 } sample_constraint { min_jaccard_overlap: 0.7 } max_sample: 1 max_trials: 50 } batch_sampler { sampler { min_scale: 0.3 max_scale: 1.0 min_aspect_ratio: 0.5 max_aspect_ratio: 2.0 } sample_constraint { min_jaccard_overlap: 0.9 } max_sample: 1 max_trials: 50 } batch_sampler { sampler { min_scale: 0.3 max_scale: 1.0 min_aspect_ratio: 0.5 max_aspect_ratio: 2.0 } sample_constraint { max_jaccard_overlap: 1.0 } max_sample: 1 max_trials: 50 } label_map_file: "E:/imglib/voc2007/labelmap_voc.prototxt" } } layer { name: "conv1" type: "Convolution" bottom: "data" top: "conv1" param { lr_mult: 0 decay_mult: 0 } param { lr_mult: 0 decay_mult: 0 } convolution_param { num_output: 96 pad: 3 kernel_size: 7 stride: 2 weight_filler { type: "xavier" } bias_filler { type: "constant" value: 0 } } } layer { name: "relu1" type: "ReLU" bottom: "conv1" top: "conv1" } layer { name: "pool1" type: "Pooling" bottom: "conv1" top: "pool1" pooling_param { pool: MAX kernel_size: 3 stride: 2 pad: 1 } } layer { name: "conv2" type: "Convolution" bottom: "pool1" top: "conv2" param { lr_mult: 0 decay_mult: 0 } param { lr_mult: 0 decay_mult: 0 } convolution_param { num_output: 256 pad: 2 kernel_size: 5 stride: 2 weight_filler { type: "xavier" } bias_filler { type: "constant" value: 0 } } } layer { name: "relu2" type: "ReLU" bottom: "conv2" top: "conv2" } layer { name: "pool2" type: "Pooling" bottom: "conv2" top: "pool2" pooling_param { pool: MAX kernel_size: 3 stride: 2 pad: 1 } } layer { name: "conv3" type: "Convolution" bottom: "pool2" top: "conv3" param { lr_mult: 1 decay_mult: 1 } param { lr_mult: 2 decay_mult: 0 } convolution_param { num_output: 384 pad: 1 kernel_size: 3 stride: 1 weight_filler { type: "xavier" } bias_filler { type: "constant" value: 0 } } } layer { name: "relu3" type: "ReLU" bottom: "conv3" top: "conv3" } layer { name: "conv4" type: "Convolution" bottom: "conv3" top: "conv4" param { lr_mult: 1 decay_mult: 1 } param { lr_mult: 2 decay_mult: 0 } convolution_param { num_output: 384 pad: 1 kernel_size: 3 stride: 1 weight_filler { type: "xavier" } bias_filler { type: "constant" value: 0 } } } layer { name: "relu4" type: "ReLU" bottom: "conv4" top: "conv4" } layer { name: "conv5" type: "Convolution" bottom: "conv4" top: "conv5" param { lr_mult: 1 decay_mult: 1 } param { lr_mult: 2 decay_mult: 0 } 
convolution_param { num_output: 256 pad: 1 kernel_size: 3 stride: 1 weight_filler { type: "xavier" } bias_filler { type: "constant" value: 0 } } } layer { name: "relu5" type: "ReLU" bottom: "conv5" top: "conv5" }
Everything ELSE is the same as the VGG-based model...
@ChristineRYY I did add a scale in train_transform_param, which generates a train.prototxt like the one below.
name: "ZF_VOC0712_SSD_300x300_train"
layer {
name: "data"
type: "AnnotatedData"
top: "data"
top: "label"
include {
phase: TRAIN
}
transform_param {
scale: 0.0078125
mirror: true
mean_value: 104
mean_value: 117
mean_value: 123
resize_param {
prob: 1
resize_mode: WARP
height: 300
width: 300
interp_mode: LINEAR
interp_mode: AREA
interp_mode: NEAREST
interp_mode: CUBIC
interp_mode: LANCZOS4
}
emit_constraint {
emit_type: CENTER
}
}
data_param {
source: "examples/VOC0712/VOC0712_trainval_lmdb"
batch_size: 16
backend: LMDB
}
annotated_data_param {
# the followings are samplers and layer definitions..
Notice that there is a scale: 0.0078125 line, which I did not see in yours.
Recently I thought about this scale factor again, and I found that applying mean_value and scale has a similar effect to adding a BatchNorm layer right after the data layer, as both transform the input data toward an N(0, 1) distribution. I'll go a bit deeper and do some more research on the difference between scale and batch norm.
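If anyone wants to try the BatchNorm-right-after-data variant instead of the fixed scale, a rough and untested NetSpec sketch would look like this, assuming the usual net object and L alias from the generation script:

# Untested sketch: normalize the input with BatchNorm + Scale instead of a fixed 1/128 factor.
# Assumes `net` is the caffe.NetSpec from the generation script and `net.data` already exists.
from caffe import layers as L

net.data_bn = L.BatchNorm(net.data,
        param=[dict(lr_mult=0), dict(lr_mult=0), dict(lr_mult=0)])
net.data_scale = L.Scale(net.data_bn, bias_term=True, in_place=True)
# ...and then feed net.data_scale (instead of net.data) into conv1.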
@hyichao @ChristineRYY If you are using conv5 or an even lower layer as one of the mbox_source_layers, have you set normalizations as done here? ZFNet has the same problem as VGGNet because neither is trained with batch normalization, so the activation scales of the early layers in the network are large.
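For anyone unsure what that looks like in the net-generation script: each entry of the normalizations list that is not -1 makes CreateMultiBoxHead insert an L2 Normalize layer for that source. A simplified sketch of what gets added; the layer name and the initial scale of 20 are just the conv4_3-style defaults, not something specific to ZF:

# Simplified sketch of the Normalize layer that a non -1 `normalizations` entry produces.
# Assumes `net` is the caffe.NetSpec from the generation script.
from caffe import layers as L

norm_name = 'conv2_norm'  # hypothetical name for a ZF source layer
net[norm_name] = L.Normalize(net['conv2'],
        scale_filler=dict(type='constant', value=20),
        across_spatial=False, channel_shared=False)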
@hyichao If you set scale in the input image layer, it is almost equivalent to training from scratch, I think, because the input is completely different from what ZFNet was pretrained on.
@hyichao I'm very sorry to tell you that my training still doesn't converge after I add scale: 0.0078125.
@ChristineRYY Have you pulled the latest code? And have you tried to make clean and compile the whole thing again? It should converge. I am training a ZFNet and will provide a reference script for it. I would expect it to be around 10 points lower than VGGNet. We will see tomorrow.
@ChristineRYY @hyichao You can check examples/ssd/ssd_pascal_zf.py for how to train with ZFNet. It can achieve about 62.* mAP with that script.
@weiliu89 Thank you very much !!!
@weiliu89 Thanks a lot, and I did reach 62 with this script. But I still have several questions I hope can be answered:
- You kept the LRN layers after Conv1 and Conv2; do they make a contribution to convergence?
- Why did you pick Conv2 as one of the mbox_source_layers?
Thanks again.

I didn't really optimize the performance. I just tried to mimic the same thing I did with VGGNet.
I just keep LRN. Nothing special here.
I first converted fc6 and fc7 to convolution layers, and then subsampled them -- similar to what I did for VGG (see the sketch after these answers). This way it reduces the number of parameters (i.e. model size), as well as improving the speed a bit.
conv2 has a similar role to conv4_3 from VGG.
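A simplified net-surgery sketch of that "convert fc to conv, then subsample" step; the shapes are the VGG fc6 ones and purely illustrative, since ZF's fc layers have different sizes:

# Simplified net-surgery sketch of "convert fc to conv, then subsample".
# Shapes are the VGG fc6 ones, purely for illustration; ZF's differ.
import numpy as np

fc_w = np.random.randn(4096, 25088)      # stand-in for pretrained fc6 weights (4096 x 512*7*7)
conv_w = fc_w.reshape(4096, 512, 7, 7)   # fc6 viewed as a 7x7 convolution over 512 channels

# Keep every 4th output filter and every 3rd kernel position, giving a smaller
# 1024 x 512 x 3 x 3 filter bank (used together with dilation at inference).
sub_w = conv_w[::4, :, ::3, ::3]
print(sub_w.shape)                       # (1024, 512, 3, 3)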
Hi,
I'm new to SSD and currently trying to use ZFNet instead of VGG as the base network, while still adding the extra layers. Unfortunately I could not even reach 50 mAP. I've looked into some papers and learned that even Fast YOLO achieves 52.7.
Has anybody made an attempt with some other base network and found something useful?
Thanks