Also, I am a bit confused about this implementation. I implemented the SE layer like this; is it correct?
class SELayer(nn.HybridBlock):
    def __init__(self, channel, in_channel, reduction=16, **kwargs):
        super(SELayer, self).__init__(**kwargs)
        with self.name_scope():
            self.avg_pool = nn.GlobalAvgPool2D()
            self.fc = nn.HybridSequential()
            with self.fc.name_scope():
                self.fc.add(nn.Conv2D(channel // reduction, kernel_size=1, in_channels=in_channel))
                self.fc.add(nn.PReLU())
                self.fc.add(nn.Conv2D(channel, kernel_size=1, in_channels=channel // reduction))
                self.fc.add(nn.Activation('sigmoid'))

    def hybrid_forward(self, F, x, *args, **kwargs):
        y = self.avg_pool(x)  # squeeze: (N, C, H, W) -> (N, C, 1, 1)
        y = self.fc(y)        # excite: per-channel gates in (0, 1)
        return F.broadcast_mul(x, y)
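If it helps, a quick sanity check like the one below (the batch size, channel count, and spatial size are just placeholders) confirms that GlobalAvgPool2D produces the (N, C, 1, 1) tensor that the 1x1 convolutions and broadcast_mul expect:

import mxnet as mx

# Placeholder check: a 64-channel feature map, batch of 2.
se = SELayer(channel=64, in_channel=64)
se.initialize()

x = mx.nd.random.uniform(shape=(2, 64, 32, 32))
y = se(x)
print(y.shape)  # expected: (2, 64, 32, 32), i.e. x rescaled channel by channel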
For question 1, I did some experiments. I used the network below for all of them.
class LeNet_m(nn.HybridBlock):
    r"""LeNet_m is the lenet++ model from
    `"A Discriminative Feature Learning Approach for Deep Face Recognition"
    <https://ydwen.github.io/papers/WenECCV16.pdf>`_ paper.

    Parameters
    ----------
    embedding_size : int
        Number of units in the embedding layer.
    """
    def __init__(self, embedding_size=2, **kwargs):
        super().__init__(**kwargs)
        self.feature = nn.HybridSequential()
        self.feature.add(
            nn.Conv2D(32, 5, padding=2, strides=1),
            nn.PReLU(),
            nn.Conv2D(32, 5, padding=2, strides=1),
            nn.PReLU(),
            nn.MaxPool2D(2, strides=2),
            nn.Conv2D(64, 5, padding=2, strides=1),
            nn.PReLU(),
            nn.Conv2D(64, 5, padding=2, strides=1),
            nn.PReLU(),
            nn.MaxPool2D(2, strides=2),
            nn.Conv2D(128, 5, padding=2, strides=1),
            nn.PReLU(),
            nn.Conv2D(128, 5, padding=2, strides=1),
            nn.PReLU(),
            nn.MaxPool2D(2, strides=2),
            nn.Flatten(),
            nn.Dense(embedding_size),
            nn.PReLU()
        )
        self.output = NormDense(10, weight_norm=True, feature_norm=True, in_units=embedding_size)

    def hybrid_forward(self, F, x, *args, **kwargs):
        embedding = self.feature(x)
        output = self.output(embedding)
        return embedding, output
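Just to double-check the model itself, a shape probe along these lines (a rough sketch; I'm assuming a 1x28x28 input as in the MNIST setting of the lenet++ paper, and NormDense is the custom head from above) should give the same embedding shape with or without nn.Flatten(), because nn.Dense flattens its input by default:

import mxnet as mx

net = LeNet_m(embedding_size=2)
net.initialize()

x = mx.nd.random.uniform(shape=(4, 1, 28, 28))  # assumed MNIST-like input
embedding, output = net(x)
print(embedding.shape)  # (4, 2) with or without nn.Flatten(), since nn.Dense flattens by default
print(output.shape)     # one score per class from NormDense, presumably (4, 10)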
When using SoftmaxCrossEntropyLoss, it seems to work both with and without nn.Flatten(). With nn.Flatten():
[epoch 0] train accuracy: 0.539571, train loss: 1.698064 | val accuracy: 0.728500, val loss: 1.590276, time: 22.469507
[epoch 1] train accuracy: 0.762123, train loss: 1.581157 | val accuracy: 0.776400, val loss: 1.577212, time: 15.238148
[epoch 2] train accuracy: 0.800143, train loss: 1.569461 | val accuracy: 0.824400, val loss: 1.566522, time: 16.437150
[epoch 3] train accuracy: 0.814001, train loss: 1.560764 | val accuracy: 0.802400, val loss: 1.564648, time: 15.852416
[epoch 4] train accuracy: 0.810616, train loss: 1.558472 | val accuracy: 0.780500, val loss: 1.560144, time: 16.293747
[epoch 5] train accuracy: 0.798843, train loss: 1.554051 | val accuracy: 0.789400, val loss: 1.559037, time: 21.917863
[epoch 6] train accuracy: 0.786903, train loss: 1.553903 | val accuracy: 0.756500, val loss: 1.557781, time: 15.730007
[epoch 7] train accuracy: 0.795424, train loss: 1.550597 | val accuracy: 0.765000, val loss: 1.559362, time: 15.795394
[epoch 8] train accuracy: 0.792006, train loss: 1.549249 | val accuracy: 0.800600, val loss: 1.555520, time: 15.878381
[epoch 9] train accuracy: 0.796041, train loss: 1.548154 | val accuracy: 0.772800, val loss: 1.555308, time: 16.775593
[epoch 10] train accuracy: 0.795291, train loss: 1.547209 | val accuracy: 0.772900, val loss: 1.554858, time: 21.924227
Without nn.Flatten():
[epoch 0] train accuracy: 0.573589, train loss: 1.666234 | val accuracy: 0.754200, val loss: 1.582964, time: 21.863318
[epoch 1] train accuracy: 0.760572, train loss: 1.577847 | val accuracy: 0.796100, val loss: 1.572719, time: 15.935016
[epoch 2] train accuracy: 0.812383, train loss: 1.566417 | val accuracy: 0.827500, val loss: 1.566119, time: 15.978556
[epoch 3] train accuracy: 0.538070, train loss: 1.761760 | val accuracy: 0.750600, val loss: 1.582934, time: 15.977902
[epoch 4] train accuracy: 0.771078, train loss: 1.574963 | val accuracy: 0.771600, val loss: 1.573348, time: 16.360459
[epoch 5] train accuracy: 0.801944, train loss: 1.565094 | val accuracy: 0.820000, val loss: 1.562364, time: 21.755370
[epoch 6] train accuracy: 0.824556, train loss: 1.558800 | val accuracy: 0.827400, val loss: 1.561051, time: 16.179930
[epoch 7] train accuracy: 0.839698, train loss: 1.555389 | val accuracy: 0.837700, val loss: 1.557494, time: 16.182189
[epoch 8] train accuracy: 0.841165, train loss: 1.555763 | val accuracy: 0.813900, val loss: 1.564987, time: 16.511432
[epoch 9] train accuracy: 0.843617, train loss: 1.554402 | val accuracy: 0.839500, val loss: 1.555551, time: 16.007684
[epoch 10] train accuracy: 0.859925, train loss: 1.550805 | val accuracy: 0.844500, val loss: 1.555057, time: 21.957124
But the situation is different when the loss is not SoftmaxCrossEntropyLoss, for example ArcFace, which I define like this:
class ArcLoss(SoftmaxCrossEntropyLoss):
    r"""ArcLoss from
    `"ArcFace: Additive Angular Margin Loss for Deep Face Recognition"
    <https://arxiv.org/abs/1801.07698>`_ paper.

    Parameters
    ----------
    classes : int
        Number of classes.
    s : float
        Scale parameter for the loss.
    m : float
        Additive angular margin.
    """
    def __init__(self, classes, s, m, easy_margin=True,
                 axis=-1, sparse_label=True, weight=None, batch_axis=0, **kwargs):
        super().__init__(axis=axis, sparse_label=sparse_label,
                         weight=weight, batch_axis=batch_axis, **kwargs)
        assert s > 0.
        assert 0 <= m < (math.pi / 2)
        self.s = s
        self.m = m
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.mm = math.sin(math.pi - m) * m
        self.threshold = math.cos(math.pi - m)
        self._classes = classes
        self.easy_margin = easy_margin

    def hybrid_forward(self, F, pred, label, sample_weight=None, *args, **kwargs):
        cos_t = F.pick(pred, label, axis=1)  # cos(theta_yi)
        if self.easy_margin:
            cond = F.Activation(data=cos_t, act_type='relu')
        else:
            cond_v = cos_t - self.threshold
            cond = F.Activation(data=cond_v, act_type='relu')
        sin_t = F.sqrt(1.0 - cos_t * cos_t)  # sin(theta_yi)
        new_zy = cos_t * self.cos_m - sin_t * self.sin_m  # cos(theta_yi + m)
        if self.easy_margin:
            zy_keep = cos_t
        else:
            zy_keep = cos_t - self.mm  # cos(theta_yi) - sin(pi - m) * m
        new_zy = F.where(cond, new_zy, zy_keep)
        diff = new_zy - cos_t  # cos(theta_yi + m) - cos(theta_yi)
        diff = F.expand_dims(diff, 1)  # shape (b, 1)
        gt_one_hot = F.one_hot(label, depth=self._classes, on_value=1.0, off_value=0.0)  # shape (b, classes)
        body = F.broadcast_mul(gt_one_hot, diff)
        pred = pred + body
        pred = pred * self.s
        return super().hybrid_forward(F, pred=pred, label=label, sample_weight=sample_weight)
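For reference, a minimal smoke test like this (s=64 and m=0.5 are just the common values from the paper, and the random "cosine" logits are a stand-in for the NormDense output) checks that the loss runs and returns one value per sample:

import mxnet as mx

loss_fn = ArcLoss(classes=10, s=64, m=0.5, easy_margin=False)

cos_sim = mx.nd.random.uniform(-1, 1, shape=(4, 10))  # stand-in for cosine similarities
labels = mx.nd.array([0, 3, 7, 9])
loss = loss_fn(cos_sim, labels)
print(loss.shape)  # (4,): one scaled cross-entropy value per sample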
Without nn.Flatten(), the network no longer converges:
[epoch 0] train accuracy: 0.126467, train loss: 10.330848 | val accuracy: 0.143100, val loss: 10.031983, time: 21.797966
[epoch 1] train accuracy: 0.130670, train loss: 10.031971 | val accuracy: 0.134700, val loss: 10.031968, time: 15.681897
[epoch 2] train accuracy: 0.129202, train loss: 10.031968 | val accuracy: 0.160600, val loss: 10.031967, time: 15.259615
[epoch 3] train accuracy: 0.128919, train loss: 10.031967 | val accuracy: 0.141300, val loss: 10.031965, time: 16.882991
[epoch 4] train accuracy: 0.126384, train loss: 10.031966 | val accuracy: 0.128000, val loss: 10.031965, time: 16.948013
[epoch 5] train accuracy: 0.129236, train loss: 10.031966 | val accuracy: 0.136300, val loss: 10.031962, time: 22.351207
[epoch 6] train accuracy: 0.131503, train loss: 10.031965 | val accuracy: 0.145300, val loss: 10.031964, time: 16.403721
[epoch 7] train accuracy: 0.130203, train loss: 10.031964 | val accuracy: 0.133000, val loss: 10.031961, time: 16.072814
[epoch 8] train accuracy: 0.130420, train loss: 10.031963 | val accuracy: 0.121500, val loss: 10.031962, time: 16.801162
[epoch 9] train accuracy: 0.130786, train loss: 10.031963 | val accuracy: 0.126100, val loss: 10.031961, time: 16.463526
[epoch 10] train accuracy: 0.128935, train loss: 10.031962 | val accuracy: 0.163800, val loss: 10.031962, time: 21.527128
[epoch 11] train accuracy: 0.132221, train loss: 10.031961 | val accuracy: 0.130500, val loss: 10.031959, time: 17.993142
But with nn.Flatten(), the network converges normally again:
[epoch 0] train accuracy: 0.116379, train loss: 10.226152 | val accuracy: 0.116600, val loss: 10.031872, time: 21.957492
[epoch 1] train accuracy: 0.117946, train loss: 10.031844 | val accuracy: 0.125400, val loss: 10.031767, time: 16.485361
[epoch 2] train accuracy: 0.122549, train loss: 10.031726 | val accuracy: 0.131900, val loss: 10.031610, time: 15.446071
[epoch 3] train accuracy: 0.126251, train loss: 10.031513 | val accuracy: 0.133200, val loss: 10.031284, time: 15.949938
[epoch 4] train accuracy: 0.130019, train loss: 10.030997 | val accuracy: 0.140800, val loss: 10.030234, time: 17.152459
[epoch 5] train accuracy: 0.128752, train loss: 10.028063 | val accuracy: 0.162600, val loss: 10.020156, time: 21.593948
[epoch 6] train accuracy: 0.307531, train loss: 9.088968 | val accuracy: 0.431600, val loss: 8.001767, time: 16.123604
[epoch 7] train accuracy: 0.626851, train loss: 6.378452 | val accuracy: 0.680800, val loss: 5.582754, time: 16.175965
[epoch 8] train accuracy: 0.761856, train loss: 4.965634 | val accuracy: 0.825000, val loss: 4.025213, time: 16.160648
[epoch 9] train accuracy: 0.896561, train loss: 2.909804 | val accuracy: 0.970000, val loss: 1.465720, time: 16.322912
[epoch 10] train accuracy: 0.958761, train loss: 1.774359 | val accuracy: 0.973000, val loss: 1.219243, time: 21.532821
[epoch 11] train accuracy: 0.975303, train loss: 1.099493 | val accuracy: 0.982500, val loss: 0.850762, time: 16.211267
To rule out randomness, I repeated the ArcLoss experiments many times. I don't know why SoftmaxCrossEntropyLoss works in both cases (with or without nn.Flatten()), but there must be some difference here. I hope you can look into it, and ideally emit a warning or error when people forget it.
nn.Dense() has a parameter flatten, and it is set to True by default, which does the same thing as nn.Flatten(). This is the reason we don't add nn.Flatten() explicitly. However, your results are unexpected and I encourage you to open an issue at https://github.com/apache/incubator-mxnet/issues

AdaptiveAvgPooling2D is faster than GlobalAvgPool2D when back-propagating. Related issue: https://github.com/apache/incubator-mxnet/issues/10912

fine.

https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/se_resnet.py#L88
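Regarding the pooling remark above, if you want to try the faster contrib operator inside the SE layer, a variant could look roughly like this (a sketch only; the pooling call changes, everything else mirrors the SELayer posted at the top):

from mxnet.gluon import nn

class SELayerAdaptive(nn.HybridBlock):
    def __init__(self, channel, in_channel, reduction=16, **kwargs):
        super(SELayerAdaptive, self).__init__(**kwargs)
        with self.name_scope():
            self.fc = nn.HybridSequential()
            with self.fc.name_scope():
                self.fc.add(nn.Conv2D(channel // reduction, kernel_size=1, in_channels=in_channel))
                self.fc.add(nn.PReLU())
                self.fc.add(nn.Conv2D(channel, kernel_size=1, in_channels=channel // reduction))
                self.fc.add(nn.Activation('sigmoid'))

    def hybrid_forward(self, F, x, *args, **kwargs):
        # output_size=1 collapses each channel to a single value, like global average pooling
        y = F.contrib.AdaptiveAvgPooling2D(x, output_size=1)
        y = self.fc(y)
        return F.broadcast_mul(x, y)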