Open 111xumengze opened 4 years ago
This error should be fixed. I just deleted the variable threshold because I don't have a running environment at this point.
Hello sir, I'm also getting an error at that line. How do I rectify it? dataset = dataset.map(_parse_function(threshold))
Hello sir, I'm also getting an error at that line. How do I rectify it? dataset = dataset.map(_parse_function(threshold))
just like this: dataset = dataset.map(lambda x: _parse_function(x, threshold))
Thanks. Then on line 9 of train.py, it should be "from network", not "from networks", right?
Thanks. Then on line 9 of train.py, it should be "from network", not "from networks", right?
yes
This error occurs in train.py at the call saver.save(sess, ckptPath, globalStep=globalStep, write_meta_graph=False): TypeError: save() got an unexpected keyword argument 'globalStep'
This error occurs in train.py at the call saver.save(sess, ckptPath, globalStep=globalStep, write_meta_graph=False): TypeError: save() got an unexpected keyword argument 'globalStep'
`
# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Directory handed to getNumEpochTfrecordWithShuffle to build the TFRecord list.
tfrecordDir = './tf_record_dir/train_co'
# The prefix of file
trainFile = 'train'
checkpointDir = './output_dir/models'
logDir = './output_dir/logs/step_accuracy.txt'

# ---------------------------------------------------------------------------
# Model and training settings
# ---------------------------------------------------------------------------
num_classes = 2
is_training = True
bacthSize = 56
# Passed to _parse_function for every record; the same value is used as the
# height and width of the IMGS placeholder further down the script.
threshold = 192
howManyTimeShuffleFile = 1
howManyRepeatFileList = 80
initialLearningRate = 0.0005
learningRateDecayFactor = 0.85
keep_prob_set = 0.5

# ---------------------------------------------------------------------------
# Session configuration
# ---------------------------------------------------------------------------
config = tf.ConfigProto(log_device_placement=False)
config.gpu_options.allow_growth = True

# ---------------------------------------------------------------------------
# Input pipeline: parse -> repeat -> shuffle -> batch
# ---------------------------------------------------------------------------
fileList = getNumEpochTfrecordWithShuffle(tfrecordDir, howManyTimeShuffleFile)
# Read tfrecord_dir file
dataset = (
    tf.data.TFRecordDataset(fileList)
    .map(lambda record: _parse_function(record, threshold))
    .repeat(howManyRepeatFileList)
    .shuffle(buffer_size=3000)
    .batch(bacthSize)
)
iterator = dataset.make_one_shot_iterator()
nextImgs, nextLabels = iterator.get_next()
def getImgesLabels(sess):
    """Evaluate the dataset iterator once and return the next batch.

    Args:
        sess: Session used to run the module-level ``nextImgs`` and
            ``nextLabels`` tensors.

    Returns:
        A ``(imgs, labels)`` tuple holding the evaluated values of
        ``nextImgs`` and ``nextLabels``.
    """
    batch_images, batch_labels = sess.run([nextImgs, nextLabels])
    return batch_images, batch_labels
# Number of samples in one epoch, and how many full batches fit into it.
numSamplesPerEpoch = getTheNumOfImgInAEpoch(trainFile)
numBatchesPerEpoch = numSamplesPerEpoch // bacthSize
decay_steps = 600

# Set the verbosity to INFO level
tf.logging.set_verbosity(tf.logging.INFO)

# Placeholders that are fed on every sess.run call of the training script.
IMGS = tf.placeholder(tf.float32, shape=(bacthSize, threshold, threshold, 8))
Labels = tf.placeholder(tf.int32, shape=(bacthSize,))
keep_prob = tf.placeholder(tf.float32)

print("numSamplesPerEpoch, numBatchesPerEpoch, howManyTimeShuffleFile, howManyRepeatFileList: ",
      numSamplesPerEpoch, numBatchesPerEpoch, howManyTimeShuffleFile, howManyRepeatFileList)

# Build the network under its default argument scope.
with slim.arg_scope(DNNs_arg_scope()):
    logits, end_points = DNNs(
        IMGS,
        num_classes=num_classes,
        keep_prob=keep_prob,
        is_training=is_training,
    )

# One-hot encode the integer labels and register the softmax cross-entropy
# loss; the total loss is then read back from tf.losses.
one_hot_labels = slim.one_hot_encoding(Labels, num_classes)
cross_entropy = tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels, logits=logits)
totalLoss = tf.losses.get_total_loss()

globalStep = get_or_create_global_step()

# decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
lr = tf.train.exponential_decay(
    initialLearningRate,
    globalStep,
    decay_steps,
    learningRateDecayFactor,
    staircase=True,
)
optimizer = tf.train.AdamOptimizer(learning_rate=lr)

# Create the trainOp.
trainOp = slim.learning.create_train_op(totalLoss, optimizer)

# The probabilities of the samples in a batch, and the arg-max class of each.
probabilities = end_points['Predictions']
predictions = tf.argmax(probabilities, 1)

# Streaming accuracy: `accuracy` is the running value, `accuracy_update`
# refreshes it; metrics_op bundles the update together with the probabilities.
accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, Labels)
metrics_op = tf.group(accuracy_update, probabilities)
def train_step(sess, trainOp, globalStep, imgs, labels):
    """Run a single training step and log its loss and wall-clock duration.

    The step evaluates ``trainOp``, ``globalStep`` and the module-level
    ``metrics_op`` in one ``sess.run`` call, feeding the batch into the
    ``IMGS`` / ``Labels`` placeholders and ``keep_prob_set`` into ``keep_prob``.

    Args:
        sess: Session in which the step is executed.
        trainOp: Op whose evaluated value is reported as the loss.
        globalStep: Global-step tensor evaluated alongside the train op.
        imgs: Batch of images fed into ``IMGS``.
        labels: Batch of labels fed into ``Labels``.

    Returns:
        A ``(loss, step)`` tuple: the value produced by ``trainOp`` and the
        evaluated global step.
    """
    began = time.time()
    feed = {IMGS: imgs, Labels: labels, keep_prob: keep_prob_set}
    loss_value, step_count, _ = sess.run([trainOp, globalStep, metrics_op], feed_dict=feed)
    elapsed = time.time() - began

    # Run the logging to print some results
    logging.info('global step %s: loss: %.4f (%.2f sec/step)', step_count, loss_value, elapsed)

    # return total loss and which step
    return loss_value, step_count
# Make sure every output location exists before anything is written to it.
# os.makedirs is used because checkpointDir is nested below ./output_dir, and
# the log file's parent directory must exist before it can be opened.
if not os.path.exists(checkpointDir):
    os.makedirs(checkpointDir)
logParentDir = os.path.dirname(logDir)
if logParentDir and not os.path.exists(logParentDir):
    os.makedirs(logParentDir)

saver = tf.train.Saver(max_to_keep=200)
ckptPath = os.path.join(checkpointDir, 'mode.ckpt')

# A checkpoint and a log line are written every this many steps.
stepsPerCheckpoint = 150
# Total number of training steps (batches) that will be run.
totalSteps = numBatchesPerEpoch * howManyTimeShuffleFile * howManyRepeatFileList

lossAvg = 0    # loss accumulated since the last checkpoint/log line
lossTotal = 0  # loss accumulated since the start of the current epoch

with tf.Session(config=config) as sess:
    # Non-deprecated replacements for tf.initialize_all_variables() and
    # tf.initialize_local_variables().
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Resume from the newest checkpoint when there is one. latest_checkpoint
    # returns None if the directory holds no checkpoint, even when the
    # directory itself is not empty, so test the result, not the listing.
    model_file = tf.train.latest_checkpoint(checkpointDir)
    if model_file is not None:
        saver.restore(sess, model_file)

    # NOTE: the value printed here is the number of steps, not of samples.
    print('total samples: %d' % totalSteps)

    for step in range(totalSteps):
        imgs, labels = getImgesLabels(sess)
        feed = {IMGS: imgs, Labels: labels, keep_prob: keep_prob_set}

        # for each epoch
        if step % numBatchesPerEpoch == 0:
            # Integer division keeps the epoch number an int on Python 3.
            logging.info('Epoch %s/%s', step // numBatchesPerEpoch + 1,
                         howManyTimeShuffleFile * howManyRepeatFileList)
            learning_rate_value, accuracy_value = sess.run([lr, accuracy], feed_dict=feed)
            logging.info('Current Learning Rate: %s', learning_rate_value)
            logging.info('Current Streaming Accuracy: %s', accuracy_value)

            logits_value, probabilities_value, predictions_value = sess.run(
                [logits, probabilities, predictions], feed_dict=feed)
            print('logits: \n', logits_value[:5])
            print('Probabilities: \n', probabilities_value[:5])
            print('lables: :{}\n'.format(labels))
            print('predictions:{}\n'.format(predictions_value))
            print(lossTotal)
            lossTotal = 0

        loss, globalStepCount = train_step(sess, trainOp, globalStep, imgs, labels)
        lossAvg += loss
        lossTotal += loss

        # how many step to save a ckpt file
        if step % stepsPerCheckpoint == 0:
            learning_rate_value, accuracy_value = sess.run([lr, accuracy], feed_dict=feed)
            print('learning_rate_value: {}\n accuracy_value: {}'.format(learning_rate_value, accuracy_value))

            with open(logDir, 'a+') as logFile:
                line_str = (
                    'learining_rate: ' + str(learning_rate_value)
                    + ' global_step: ' + str(globalStepCount)
                    + ' loss: ' + str(lossAvg / stepsPerCheckpoint)
                    + ' accuracy_value: ' + str(accuracy_value) + '\n'
                )
                print(line_str)
                logFile.write(line_str)
            lossAvg = 0

            saver.save(sess, ckptPath, global_step=globalStep, write_meta_graph=False)

            # Export the meta graph only once. os.path.exists does not expand
            # wildcards, so look for an existing .meta file explicitly.
            hasMetaGraph = any(
                name.endswith('.meta') for name in os.listdir(checkpointDir))
            if not hasMetaGraph:
                saver.export_meta_graph(checkpointDir + '/mode.ckpt.meta')
`
Thanks Sir
Is that right? How many steps will it run for?
Is that right? How many steps will it run for?
It runs for numBatchesPerEpoch * howManyTimeShuffleFile * howManyRepeatFileList steps.
What happens if I reduce the number of epochs? It is set to 80. Can I reduce it to the minimum number?
x is missed. And thank you for your code and paper!