czhu95 / ternarynet

Implementation for Trained Ternary Network.
Apache License 2.0

Reproduced ResNet18 CIFAR10 result is 10% lower than reported #7

Open · csyhhu opened this issue 6 years ago

csyhhu commented 6 years ago

Hi @czhu95 ,

Thanks for providing the codes!

Recently I used your code to ternarize a ResNet18 on CIFAR10. First I used tensorpack to train a full-precision ResNet18 down to a validation error of 0.083. However, when I use that model as the initial state and ternarize it (with the example code) using your default threshold t=0.05, the validation error stays around 0.1843. I tried other values of t, but it is still around 0.18, which is about 10% worse than what your paper reports.
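For readers unfamiliar with the `t` parameter discussed here: in Trained Ternary Quantization the per-layer threshold is a fraction `t` of the layer's largest absolute weight, and the two learned scales (`Wp`/`Wn`) replace the surviving positive and negative weights. Below is a minimal NumPy sketch of that forward rule from the paper, as an illustration only, not the repo's exact `ternarize` implementation:

```python
import numpy as np

def ternarize_forward(w, t=0.05, w_p=1.0, w_n=1.0):
    """Forward pass of TTQ-style ternarization (sketch).

    w   : full-precision weight tensor
    t   : threshold fraction (the default t=0.05 discussed above)
    w_p : learned positive scaling factor
    w_n : learned negative scaling factor
    """
    delta = t * np.max(np.abs(w))                    # per-layer threshold
    ternary = np.zeros_like(w, dtype=np.float32)
    ternary[w > delta] = w_p                         # large positive weights -> +Wp
    ternary[w < -delta] = -w_n                       # large negative weights -> -Wn
    return ternary                                   # everything in between -> 0
```

With `t=0.05`, any weight whose magnitude is below 5% of the layer's maximum is zeroed out.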

Are there any tricks I'm missing, or did I make a mistake somewhere?

Best regards, Shangyu

blueardour commented 5 years ago

@csyhhu Do you have any update on the accuracy? I got 87% on CIFAR10, still 4 points lower than the paper.

csyhhu commented 5 years ago

Hi @blueardour I don't have any update since then.

yy665 commented 5 years ago

Hi @csyhhu @blueardour @czhu95, thanks for providing the code! I also got accuracy about 10% lower than reported.

The main problem I have is that some of the trained values for Wp/Wn come out negative, which means the two ternary weight levels end up with the same sign. Is that a big deal, or can I just ignore it? Do I need to fix it? Is there anything I'm missing?

I modified the ResNet from the tensorpack tutorial and applied the configs from this repo. Sorry, I'm new to quantization problems, so this might be a dumb question. I would appreciate it if you could answer it for me.

Model (ResNet 34 on CIFAR 10):

```python
def inputs(self):
    return [tf.TensorSpec([None, 32, 32, 3], tf.float32, 'input'),
            tf.TensorSpec([None], tf.int32, 'label')]

def build_graph(self, image, label):
    image = image / 128.0
    assert tf.test.is_gpu_available()
    image = tf.transpose(image, [0, 3, 1, 2])

    blocks = [3,4,6,3]

    def new_get_variable(v):
        # don't binarize first and last layer
        if not v.op.name.endswith('W') or 'conv0' in v.op.name or 'fct' in v.op.name:
            return v
        else:
            logger.info("Quantizing weight {}".format(v.op.name))
            return ternarize(v, args.t)
            #return v

    with remap_variables(new_get_variable), argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm],
                  data_format='channels_first'), \
            argscope(Conv2D, use_bias=False):
        logits = (LinearWrap(image)
                  .Conv2D('conv0', 64, 7, strides=2, activation=BNReLU, padding='VALID')
                  .MaxPooling('pool0', 3, strides=2, padding='SAME')
                  .apply2(resnet_group, 'group0', resnet_basicblock, 64, blocks[0], 1)
                  .apply2(resnet_group, 'group1', resnet_basicblock, 128, blocks[1], 2)
                  .apply2(resnet_group, 'group2', resnet_basicblock, 256, blocks[2], 2)
                  .apply2(resnet_group, 'group3', resnet_basicblock, 512, blocks[3], 2)
                  .GlobalAvgPooling('gap')
                  .FullyConnected('linear', 1000)())

    tf.nn.softmax(logits, name='output')

    cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
    cost = tf.reduce_mean(cost, name='cross_entropy_loss')

    wrong = tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, 1)), tf.float32, name='wrong_vector')
    # monitor training error
    add_moving_summary(tf.reduce_mean(wrong, name='train_error'))

    # weight decay on all W of fc layers
    wd_w = tf.train.exponential_decay(0.0002, get_global_step_var(),
                                      480000, 0.2, True)
    wd_cost = tf.multiply(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
    add_moving_summary(cost, wd_cost)

    add_param_summary(('.*/W', ['histogram']))   # monitor W
    return tf.add_n([cost, wd_cost], name='cost')

def resnet_basicblock(l, ch_out, stride):
    shortcut = l
    l = BatchNorm('bn1', l)
    #l = tf.nn.relu(l)
    l = Conv2D('conv1', l, ch_out, 3, strides=stride, activation=BNReLU)
    #l = Conv2D('conv1', l, ch_out, 3, strides=stride)
    l = Conv2D('conv2', l, ch_out, 3, activation=get_bn(zero_init=True))
    out = l + resnet_shortcut(shortcut, ch_out, stride, activation=get_bn(zero_init=False))
    return tf.nn.relu(out)

def resnet_shortcut(l, n_out, stride, activation=tf.identity):
    data_format = get_arg_scope()['Conv2D']['data_format']
    n_in = l.get_shape().as_list()[1 if data_format in ['NCHW', 'channels_first'] else 3]
    if n_in != n_out:   # change dimension when channel is not the same
        return Conv2D('convshortcut', l, n_out, 1, strides=stride, activation=activation)
    else:
        return l

def resnet_group(name, l, block_func, features, count, stride):
    with tf.variable_scope(name):
        for i in range(0, count):
            with tf.variable_scope('block{}'.format(i)):
                l = block_func(l, features, stride if i == 0 else 1)
    return l
```

Final Train Output:

```
[0703 23:52:34 @base.py:275] Start Epoch 400 ...
[0703 23:52:48 @base.py:285] Epoch 400 (global_step 156000) finished, time:13.9 seconds.
[0703 23:52:48 @graph.py:73] Running Op sync_variables/sync_variables_from_main_tower ...
[0703 23:52:50 @saver.py:79] Model saved to train_log/cifar10-resnet34-new/model-156000.
[0703 23:52:52 @monitor.py:467] QueueInput/queue_size: 50
[0703 23:52:52 @monitor.py:467] cross_entropy_loss: 0.089525
[0703 23:52:52 @monitor.py:467] group0/block0/conv1/Wn: 2.6743
[0703 23:52:52 @monitor.py:467] group0/block0/conv1/Wp: 0.89064
[0703 23:52:52 @monitor.py:467] group0/block0/conv2/Wn: 1.1225
[0703 23:52:52 @monitor.py:467] group0/block0/conv2/Wp: 0.99695
[0703 23:52:52 @monitor.py:467] group0/block1/conv1/Wn: 1.2538
[0703 23:52:52 @monitor.py:467] group0/block1/conv1/Wp: 0.98401
[0703 23:52:52 @monitor.py:467] group0/block1/conv2/Wn: 0.80803
[0703 23:52:52 @monitor.py:467] group0/block1/conv2/Wp: 0.81799
[0703 23:52:52 @monitor.py:467] group0/block2/conv1/Wn: 0.86712
[0703 23:52:52 @monitor.py:467] group0/block2/conv1/Wp: 0.549
[0703 23:52:52 @monitor.py:467] group0/block2/conv2/Wn: 0.40106
[0703 23:52:52 @monitor.py:467] group0/block2/conv2/Wp: 0.70978
[0703 23:52:52 @monitor.py:467] group1/block0/conv1/Wn: 0.85778
[0703 23:52:52 @monitor.py:467] group1/block0/conv1/Wp: 0.82855
[0703 23:52:52 @monitor.py:467] group1/block0/conv2/Wn: 0.88484
[0703 23:52:52 @monitor.py:467] group1/block0/conv2/Wp: 0.73929
[0703 23:52:52 @monitor.py:467] group1/block0/convshortcut/Wn: 0.73418
[0703 23:52:52 @monitor.py:467] group1/block0/convshortcut/Wp: 0.53777
[0703 23:52:52 @monitor.py:467] group1/block1/conv1/Wn: 0.33549
[0703 23:52:52 @monitor.py:467] group1/block1/conv1/Wp: 0.39828
[0703 23:52:52 @monitor.py:467] group1/block1/conv2/Wn: 0.41539
[0703 23:52:52 @monitor.py:467] group1/block1/conv2/Wp: 0.31004
[0703 23:52:52 @monitor.py:467] group1/block2/conv1/Wn: 0.29682
[0703 23:52:52 @monitor.py:467] group1/block2/conv1/Wp: 0.44211
[0703 23:52:52 @monitor.py:467] group1/block2/conv2/Wn: 0.32866
[0703 23:52:52 @monitor.py:467] group1/block2/conv2/Wp: 0.47317
[0703 23:52:52 @monitor.py:467] group1/block3/conv1/Wn: 0.33542
[0703 23:52:52 @monitor.py:467] group1/block3/conv1/Wp: 0.17248
[0703 23:52:52 @monitor.py:467] group1/block3/conv2/Wn: -0.036823
[0703 23:52:52 @monitor.py:467] group1/block3/conv2/Wp: 0.13485
[0703 23:52:52 @monitor.py:467] group2/block0/conv1/Wn: 0.82662
[0703 23:52:52 @monitor.py:467] group2/block0/conv1/Wp: 0.6845
[0703 23:52:52 @monitor.py:467] group2/block0/conv2/Wn: 0.75725
[0703 23:52:52 @monitor.py:467] group2/block0/conv2/Wp: 0.79772
[0703 23:52:52 @monitor.py:467] group2/block0/convshortcut/Wn: 0.26197
[0703 23:52:52 @monitor.py:467] group2/block0/convshortcut/Wp: 0.40541
[0703 23:52:52 @monitor.py:467] group2/block1/conv1/Wn: -0.12812
[0703 23:52:52 @monitor.py:467] group2/block1/conv1/Wp: 0.23041
[0703 23:52:52 @monitor.py:467] group2/block1/conv2/Wn: -0.0099889
[0703 23:52:52 @monitor.py:467] group2/block1/conv2/Wp: 0.13627
[0703 23:52:52 @monitor.py:467] group2/block2/conv1/Wn: 0.12555
[0703 23:52:52 @monitor.py:467] group2/block2/conv1/Wp: 0.085846
[0703 23:52:52 @monitor.py:467] group2/block2/conv2/Wn: 0.12783
[0703 23:52:52 @monitor.py:467] group2/block2/conv2/Wp: 0.14273
[0703 23:52:52 @monitor.py:467] group2/block3/conv1/Wn: 0.085759
[0703 23:52:52 @monitor.py:467] group2/block3/conv1/Wp: 0.10036
[0703 23:52:52 @monitor.py:467] group2/block3/conv2/Wn: 0.12674
[0703 23:52:52 @monitor.py:467] group2/block3/conv2/Wp: 0.10762
[0703 23:52:52 @monitor.py:467] group2/block4/conv1/Wn: 0.016841
[0703 23:52:52 @monitor.py:467] group2/block4/conv1/Wp: 0.069272
[0703 23:52:52 @monitor.py:467] group2/block4/conv2/Wn: -0.063544
[0703 23:52:52 @monitor.py:467] group2/block4/conv2/Wp: 0.16016
[0703 23:52:52 @monitor.py:467] group2/block5/conv1/Wn: 0.11573
[0703 23:52:52 @monitor.py:467] group2/block5/conv1/Wp: 0.057985
[0703 23:52:52 @monitor.py:467] group2/block5/conv2/Wn: -0.19735
[0703 23:52:52 @monitor.py:467] group2/block5/conv2/Wp: 0.07127
[0703 23:52:52 @monitor.py:467] group3/block0/conv1/Wn: 0.63527
[0703 23:52:52 @monitor.py:467] group3/block0/conv1/Wp: 0.19355
[0703 23:52:52 @monitor.py:467] group3/block0/conv2/Wn: 0.31411
[0703 23:52:52 @monitor.py:467] group3/block0/conv2/Wp: -0.1947
[0703 23:52:52 @monitor.py:467] group3/block0/convshortcut/Wn: 0.54178
[0703 23:52:52 @monitor.py:467] group3/block0/convshortcut/Wp: 0.81051
[0703 23:52:52 @monitor.py:467] group3/block1/conv1/Wn: 0.4934
[0703 23:52:52 @monitor.py:467] group3/block1/conv1/Wp: 0.44481
[0703 23:52:52 @monitor.py:467] group3/block1/conv2/Wn: 0.70687
[0703 23:52:52 @monitor.py:467] group3/block1/conv2/Wp: -0.068037
[0703 23:52:52 @monitor.py:467] group3/block2/conv1/Wn: -0.84213
[0703 23:52:52 @monitor.py:467] group3/block2/conv1/Wp: 0.28623
[0703 23:52:52 @monitor.py:467] group3/block2/conv2/Wn: 0.357
[0703 23:52:52 @monitor.py:467] group3/block2/conv2/Wp: -0.71154
[0703 23:52:52 @monitor.py:467] linear/Wn: 0.76216
[0703 23:52:52 @monitor.py:467] linear/Wp: 0.86959
[0703 23:52:52 @monitor.py:467] train_error: 0.02892
[0703 23:52:52 @monitor.py:467] validation_cost: 0.62721
[0703 23:52:52 @monitor.py:467] validation_error: 0.139
[0703 23:52:52 @monitor.py:467] wd_cost: 30406
[0703 23:52:52 @group.py:48] Callbacks took 3.280 sec in total. ModelSaver: 1.75 seconds; InferenceRunner: 1.42 seconds
[0703 23:52:52 @base.py:289] Training has finished!
```
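On the negative `Wn`/`Wp` values visible in the log above: nothing in the training objective forces these scaling factors to stay positive, so they can drift through zero during training, at which point both ternary levels of that layer carry the same sign. A quick way to list the affected layers from a saved checkpoint is sketched below; the checkpoint path is copied from the log above, and the variable names are assumed to match the summaries:

```python
import tensorflow as tf

# Checkpoint path taken from the training log above; adjust for your own run.
CKPT = 'train_log/cifar10-resnet34-new/model-156000'

reader = tf.train.NewCheckpointReader(CKPT)
for name in sorted(reader.get_variable_to_shape_map()):
    # Wp/Wn are the per-layer TTQ scaling factors reported in the summaries.
    if name.endswith('/Wp') or name.endswith('/Wn'):
        value = float(reader.get_tensor(name))
        if value < 0:
            print('negative scale: {} = {:.5f}'.format(name, value))
```

If you want to force the scales to remain positive, one option (not something the paper prescribes, as far as I can tell) would be to reparameterize each scale as `tf.nn.softplus` of an unconstrained variable, or to clip it after every update.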

blueardour commented 5 years ago

@Yulun-Yao Hi, I'm sorry to say I've given up on TTN. I spent weeks trying different optimizer strategies and modifying the gradients based on the empirical experience I had gathered, but I still found the training unstable and couldn't recover the reported accuracy.

Recently I moved to LQ-Net and DoReFa. With those I consistently obtained accuracy better than the papers report, without much effort, across many scenarios. Even for a2w1 or a1w1 bit configurations I got better results than with TTN.
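For context on the `w1` part of those bit configurations: in DoReFa-Net, 1-bit weights are typically obtained as `sign(W)` scaled by the mean absolute value, with a straight-through estimator so gradients still flow to the full-precision weights. Here is a rough TF 1.x sketch of that rule, not code from this repo or from the DoReFa authors:

```python
import tensorflow as tf

def binarize_weights_dorefa(w):
    """DoReFa-style 1-bit weight quantization (sketch).

    Forward:  E(|w|) * sign(w).
    Backward: straight-through estimator, i.e. the gradient w.r.t. w is
    passed through as if the quantization were the identity.
    """
    scale = tf.reduce_mean(tf.abs(w))
    w_bin = tf.sign(w) * scale
    # Forward value is w_bin; tf.stop_gradient keeps the backward path on w.
    return w + tf.stop_gradient(w_bin - w)
```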

yy665 commented 5 years ago

> @Yulun-Yao Hi, I'm sorry to say I've given up on TTN. I spent weeks trying different optimizer strategies and modifying the gradients based on the empirical experience I had gathered, but I still found the training unstable and couldn't recover the reported accuracy.
>
> Recently I moved to LQ-Net and DoReFa. With those I consistently obtained accuracy better than the papers report, without much effort, across many scenarios. Even for a2w1 or a1w1 bit configurations I got better results than with TTN.

Thank you for your reply and suggestions!

Btw, have you ever encountered the problem I mentioned above (negative scaling factors, so both ternary weight levels end up with the same sign)? If you did, were you able to fix it? Would you mind sharing your model and parameters?

csyhhu commented 5 years ago

Hi @Yulun-Yao, sorry for the late reply. I have also given up on TTQ. For quantization problems, maybe you can add me on WeChat (csyhhu) for further discussion.