tensorflow / tensorflow

An Open Source Machine Learning Framework for Everyone
https://tensorflow.org
Apache License 2.0

There seems to be a wrong network implementation in tensorflow/example/speech_commands/models.py #19853

Closed smallbal closed 6 years ago

smallbal commented 6 years ago

I think I found a mistake in tensorflow/tensorflow/examples/speech_commands/models.py. The function create_conv_model() is meant to build the "cnn-trad-fpool3" network with two max pool layers, as the diagram in its docstring shows, but the second max pool layer is missing from the code. The code runs without errors, but I'm not sure it is the correct "cnn-trad-fpool3" network.

I have pasted the relevant part of tensorflow/tensorflow/examples/speech_commands/models.py below and added some comments to explain my points.

def create_conv_model(fingerprint_input, model_settings, is_training):
  """Builds a standard convolutional model.
  This is roughly the network labeled as 'cnn-trad-fpool3' in the
  'Convolutional Neural Networks for Small-footprint Keyword Spotting' paper:
  http://www.isca-speech.org/archive/interspeech_2015/papers/i15_1478.pdf
  Here's the layout of the graph:
  (fingerprint_input)
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MaxPool]
          v
      [Conv2D]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
        [Relu]
          v
      [MaxPool]
          v
      [MatMul]<-(weights)
          v
      [BiasAdd]<-(bias)
          v
  This produces fairly good quality results, but can involve a large number of
  weight parameters and computations. For a cheaper alternative from the same
  paper with slightly less accuracy, see 'low_latency_conv' below.
  During training, dropout nodes are introduced after each relu, controlled by a
  placeholder.
  Args:
    fingerprint_input: TensorFlow node that will output audio feature vectors.
    model_settings: Dictionary of information about the model.
    is_training: Whether the model is going to be used for training.
  Returns:
    TensorFlow node outputting logits results, and optionally a dropout
    placeholder.
  """
  if is_training:
    dropout_prob = tf.placeholder(tf.float32, name='dropout_prob')
  input_frequency_size = model_settings['dct_coefficient_count']
  input_time_size = model_settings['spectrogram_length']
  fingerprint_4d = tf.reshape(fingerprint_input,
                              [-1, input_time_size, input_frequency_size, 1])
  first_filter_width = 8
  first_filter_height = 20
  first_filter_count = 64
  first_weights = tf.Variable(
      tf.truncated_normal(
          [first_filter_height, first_filter_width, 1, first_filter_count],
          stddev=0.01))
  first_bias = tf.Variable(tf.zeros([first_filter_count]))
  first_conv = tf.nn.conv2d(fingerprint_4d, first_weights, [1, 1, 1, 1],
                            'SAME') + first_bias
  first_relu = tf.nn.relu(first_conv)
  if is_training:
    first_dropout = tf.nn.dropout(first_relu, dropout_prob)
  else:
    first_dropout = first_relu
  # The first, and currently the only, max pool operation in the whole network.
  max_pool = tf.nn.max_pool(first_dropout, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
  second_filter_width = 4
  second_filter_height = 10
  second_filter_count = 64
  second_weights = tf.Variable(
      tf.truncated_normal(
          [
              second_filter_height, second_filter_width, first_filter_count,
              second_filter_count
          ],
          stddev=0.01))
  second_bias = tf.Variable(tf.zeros([second_filter_count]))
  second_conv = tf.nn.conv2d(max_pool, second_weights, [1, 1, 1, 1],
                             'SAME') + second_bias
  second_relu = tf.nn.relu(second_conv)
  if is_training:
    second_dropout = tf.nn.dropout(second_relu, dropout_prob)
  else:
    second_dropout = second_relu
  # The docstring diagram shows a second max pool at this point. It could be added as:
  # another_max_pool = tf.nn.max_pool(second_dropout, [1, 2, 2, 1], [1, 2, 2, 1], 'SAME')
  # The uses of "second_dropout" below would then change to "another_max_pool".
  second_conv_shape = second_dropout.get_shape()
  second_conv_output_width = second_conv_shape[2]
  second_conv_output_height = second_conv_shape[1]
  second_conv_element_count = int(
      second_conv_output_width * second_conv_output_height *
      second_filter_count)
  flattened_second_conv = tf.reshape(second_dropout,
                                     [-1, second_conv_element_count])
  label_count = model_settings['label_count']
  final_fc_weights = tf.Variable(
      tf.truncated_normal(
          [second_conv_element_count, label_count], stddev=0.01))
  final_fc_bias = tf.Variable(tf.zeros([label_count]))
  final_fc = tf.matmul(flattened_second_conv, final_fc_weights) + final_fc_bias
  if is_training:
    return final_fc, dropout_prob
  else:
    return final_fc
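
To make the effect of the missing layer concrete, here is a small shape calculation in plain Python. It is only a sketch: the input size of 98 time frames by 40 frequency coefficients and the 12 labels are assumed values for illustration and are not stated in this issue, while the filter count of 64 and the 2x2, stride-2 'SAME' pooling come from the code above. The helper names are hypothetical.

```python
import math


def same_pool_out(size, stride=2):
    """Output length along one dimension of a 'SAME'-padded pooling op."""
    return math.ceil(size / stride)


def flattened_size(time_size, freq_size, channels, num_pools):
    """Element count fed into the final fully connected layer.

    The stride-1 'SAME' convolutions keep the spatial size, so only the
    2x2, stride-2 max pools shrink the feature map.
    """
    for _ in range(num_pools):
        time_size = same_pool_out(time_size)
        freq_size = same_pool_out(freq_size)
    return time_size * freq_size * channels


# Assumed settings for illustration (not given in the issue).
SPECTROGRAM_LENGTH = 98
DCT_COEFFICIENT_COUNT = 40
LABEL_COUNT = 12
# Taken from the posted code.
SECOND_FILTER_COUNT = 64

as_posted = flattened_size(
    SPECTROGRAM_LENGTH, DCT_COEFFICIENT_COUNT, SECOND_FILTER_COUNT, num_pools=1)
with_second_pool = flattened_size(
    SPECTROGRAM_LENGTH, DCT_COEFFICIENT_COUNT, SECOND_FILTER_COUNT, num_pools=2)

print("one max pool: ", as_posted, "inputs,",
      as_posted * LABEL_COUNT, "final weights")
print("two max pools:", with_second_pool, "inputs,",
      with_second_pool * LABEL_COUNT, "final weights")
```

Under these assumptions the fully connected layer sees 62720 inputs in the posted code and 16000 with a second pool, so the two variants differ in the size of the final weight matrix even though both run without errors.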


tensorflowbutler commented 6 years ago

Thank you for your post. We noticed you have not filled out the following fields in the issue template. Could you update them if they are relevant in your case, or leave them as N/A? Thanks.
Have I written custom code
OS Platform and Distribution
TensorFlow installed from
TensorFlow version
Bazel version
CUDA/cuDNN version
GPU model and memory
Exact command to reproduce

smallbal commented 6 years ago

Thanks for replying to my issue, I have updated my comment and added more details.

bignamehyp commented 6 years ago

@petewarden can you please take a look?

tensorflowbutler commented 6 years ago

Nagging Assignee @bignamehyp: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

tensorflowbutler commented 6 years ago

Nagging Assignee @petewarden: It has been 14 days with no activity and this issue has an assignee. Please update the label and/or status accordingly.

petewarden commented 6 years ago

Thanks for spotting this! It does look like a bug, sorry about that. Since we've been using that conv model as a baseline (for example in https://arxiv.org/abs/1804.03209 ) it might be awkward to fix it at this stage. I would welcome a PR with a comment added noting this issue, and possibly an alternative conv_with_maxpool implementation if it does produce noticeably better results. Closing for now, but please refer to this if you do get a chance to work on that.
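
For anyone picking up that suggestion, the sketch below shows in plain Python what the extra 2x2, stride-2 'SAME' max pool would compute on a single channel. It is an illustration only, not TensorFlow code and not the proposed conv_with_maxpool implementation; the function name is hypothetical, and it assumes the usual 'SAME' behaviour in which a window that hangs over the edge only considers the in-bounds values.

```python
def max_pool_2x2_same(grid):
    """2x2, stride-2 max pool over a 2D list with 'SAME'-style edges.

    When a dimension is odd, the last window extends past the edge and
    only the in-bounds values are used, so the output has ceil(n / 2)
    rows and columns.
    """
    height, width = len(grid), len(grid[0])
    pooled = []
    for top in range(0, height, 2):
        row = []
        for left in range(0, width, 2):
            window = [
                grid[r][c]
                for r in range(top, min(top + 2, height))
                for c in range(left, min(left + 2, width))
            ]
            row.append(max(window))
        pooled.append(row)
    return pooled


# A 3x3 single-channel feature map, chosen only for illustration.
feature_map = [
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
]
pooled = max_pool_2x2_same(feature_map)
print(pooled)  # [[5, 6], [8, 9]]
```

In the real model the same reduction would apply independently to each of the 64 channels of the second convolution's output before the flatten and the final matmul.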