pudae / tensorflow-densenet

Tensorflow-DenseNet with ImageNet Pretrained Models
Apache License 2.0
169 stars, 59 forks

data_format arg_scope #12

Closed ilkarman closed 6 years ago

ilkarman commented 6 years ago

Pudae, apologies: you mentioned this in another issue, but I couldn't get it to work.

If I create the arg scope like so:

densenet.densenet_arg_scope(data_format='NCHW')

I get

TypeError: densenet_arg_scope() got an unexpected keyword argument 'data_format'

If I do this:

dense_args = densenet.densenet_arg_scope()
dense_args['data_format'] = "NCHW"  # This doesn't work!
print(dense_args)
with slim.arg_scope(dense_args):
    base_model, _ = densenet.densenet121(in_tensor,
                                         num_classes=out_features,
                                         is_training=is_training)

TensorFlow just seems to ignore it and uses NHWC.

pudae commented 6 years ago

I added a feature to support NCHW: [Support NCHW]. Please check whether it helps you. Thank you.

ilkarman commented 6 years ago

This is perfect, thank you! Out of interest, I compared the training time against a Keras (TF) model: this TF one took 17min10s and Keras took 18min30s for 5 epochs on 4xP100s (with synthetic data).

I also just wanted to test whether passing fused=True to batch norm would speed things up.

pudae commented 6 years ago

As far as I know, slim uses "fused batchnorm" by default whenever possible. If the fused parameter is passed as None, the slim implementation sets it to True. See slim batch_norm.
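A minimal sketch of the defaulting rule described above, assuming it works exactly as stated here. `uses_fused_kernel` and its `fusable` flag are hypothetical stand-ins for whatever internal checks slim performs; they are not slim's API.

```python
def uses_fused_kernel(fused=None, fusable=True):
    """Return whether the fused batch norm path would be taken.

    `fusable` is a hypothetical flag standing in for the "if possible"
    conditions; the real checks live inside slim's batch_norm.
    """
    if fused is None:
        fused = True  # None is treated as True, per the comment above
    return bool(fused and fusable)


default_choice = uses_fused_kernel()             # fused left as None
explicit_choice = uses_fused_kernel(fused=True)  # fused passed explicitly
print(default_choice == explicit_choice)
```

Under this assumption, passing fused=True explicitly selects the same path as leaving it unset, so it would not be expected to change training time.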