yuleung / image_forensics

A general approach to detect tampered and generated image

TypeError: _parse_function() missing 1 required positional argument: 'threshold' #2

Open 111xumengze opened 4 years ago

111xumengze commented 4 years ago
dataset = dataset.map(_parse_function(threshold))       

The argument `x` is missing here. And thank you for your code and paper!

yuleung commented 4 years ago

This error should now be fixed. I simply deleted the variable `threshold`, since I don't have a running environment at the moment.

dhvaraca commented 4 years ago

Hello sir, I'm also getting an error at that line. How do I rectify it? dataset = dataset.map(_parse_function(threshold))

111xumengze commented 4 years ago

> Hello sir, I'm also getting an error at that line. How do I rectify it? dataset = dataset.map(_parse_function(threshold))

Just like this: dataset = dataset.map(lambda x: _parse_function(x, threshold))
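
For context, `Dataset.map()` calls the mapping function with each dataset element as its only argument, so the extra `threshold` parameter has to be bound beforehand. A minimal sketch; the `functools.partial` variant is just an alternative I am adding, not something from the repo:

```python
from functools import partial

# tf.data's Dataset.map passes each serialized example as the sole argument,
# so bind `threshold` either with a lambda ...
dataset = dataset.map(lambda x: _parse_function(x, threshold))
# ... or equivalently with functools.partial (assumes the parameter is named
# `threshold`, as the original TypeError message indicates):
# dataset = dataset.map(partial(_parse_function, threshold=threshold))
```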

dhvaraca commented 4 years ago

Thanks. Then on line 9 of train.py it should be `from network`, not `from networks`, right?

111xumengze commented 4 years ago

> Thanks. Then on line 9 of train.py it should be `from network`, not `from networks`, right?

yes
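
For reference, a sketch of what that line could look like after the rename (the exact symbols depend on the repo's network.py; `DNNs` and `DNNs_arg_scope` are simply the names used in the train.py excerpt later in this thread):

```python
# hypothetical form of train.py line 9: the module is network.py, not networks.py
from network import DNNs, DNNs_arg_scope
```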

dhvaraca commented 4 years ago

This error is coming up in the train.py file:

saver.save(sess, ckptPath, globalStep=globalStep, write_meta_graph=False)
TypeError: save() got an unexpected keyword argument 'globalStep'

111xumengze commented 4 years ago

> This error is coming up in the train.py file: saver.save(sess, ckptPath, globalStep=globalStep, write_meta_graph=False) TypeError: save() got an unexpected keyword argument 'globalStep'

The keyword argument of `tf.train.Saver.save` is `global_step`, not `globalStep`. Below is the corrected train.py excerpt with that fix (and the earlier ones) applied:

```python
# NOTE: this excerpt assumes the repo's own helpers (_parse_function,
# getNumEpochTfrecordWithShuffle, getTheNumOfImgInAEpoch, DNNs, DNNs_arg_scope)
# are imported from the repo's modules; the standard imports are listed below.
import os
import glob
import time
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.framework import get_or_create_global_step
from tensorflow.python.platform import tf_logging as logging  # so logging.info respects the tf verbosity set below

tfrecordDir = './tf_record_dir/train_co'
# The prefix of file
trainFile = 'train'

checkpointDir = './output_dir/models'

logDir = './output_dir/logs/step_accuracy.txt'

num_classes = 2
is_training = True

bacthSize = 56
threshold = 192
howManyTimeShuffleFile = 1
howManyRepeatFileList = 80
initialLearningRate = 0.0005
learningRateDecayFactor = 0.85
keep_prob_set = 0.5
config = tf.ConfigProto(log_device_placement=False)
config.gpu_options.allow_growth = True

fileList = getNumEpochTfrecordWithShuffle(tfrecordDir, howManyTimeShuffleFile)

# Read tfrecord_dir file
dataset = tf.data.TFRecordDataset(fileList)
dataset = dataset.map(lambda x: _parse_function(x, threshold))
dataset = dataset.repeat(howManyRepeatFileList)
dataset = dataset.shuffle(buffer_size=3000)
dataset = dataset.batch(bacthSize)
iterator = dataset.make_one_shot_iterator()
nextImgs, nextLabels = iterator.get_next()

def getImgesLabels(sess):
    imgs, labels = sess.run([nextImgs, nextLabels])
    return imgs, labels

# Get how many batches and samples in a epoch
numSamplesPerEpoch = getTheNumOfImgInAEpoch(trainFile)
numBatchesPerEpoch = numSamplesPerEpoch // bacthSize
decay_steps = 600

# Set the verbosity to INFO level
tf.logging.set_verbosity(tf.logging.INFO)

IMGS = tf.placeholder(tf.float32, (bacthSize, threshold, threshold, 8))
Labels = tf.placeholder(tf.int32, (bacthSize,))
keep_prob = tf.placeholder(tf.float32)

print("numSamplesPerEpoch, numBatchesPerEpoch, howManyTimeShuffleFile, howManyRepeatFileList: ",
      numSamplesPerEpoch, numBatchesPerEpoch, howManyTimeShuffleFile, howManyRepeatFileList)

# Set default configuration
with slim.arg_scope(DNNs_arg_scope()):
    logits, end_points = DNNs(IMGS, num_classes=num_classes, keep_prob=keep_prob, is_training=is_training)

# Labels to one-hot encoding
one_hot_labels = slim.one_hot_encoding(Labels, num_classes)

cross_entropy = tf.losses.softmax_cross_entropy(onehot_labels=one_hot_labels, logits=logits)
totalLoss = tf.losses.get_total_loss()
globalStep = get_or_create_global_step()

# decayed_learning_rate = learining_rate * decay_rate ^ ( global_step/decay_steps )
lr = tf.train.exponential_decay(
    learning_rate=initialLearningRate,
    global_step=globalStep,
    decay_steps=decay_steps,
    decay_rate=learningRateDecayFactor,
    staircase=True)

optimizer = tf.train.AdamOptimizer(learning_rate=lr)

# Create the trainOp.
trainOp = slim.learning.create_train_op(totalLoss, optimizer)
predictions = tf.argmax(end_points['Predictions'], 1)

# The probabilities of the samples in a batch
probabilities = end_points['Predictions']
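# streaming_accuracy keeps running correct/total counts in TF local variables;
# accuracy_update refreshes them with each batch fed through metrics_op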
accuracy, accuracy_update = tf.contrib.metrics.streaming_accuracy(predictions, Labels)
metrics_op = tf.group(accuracy_update, probabilities)

def train_step(sess, trainOp, globalStep, imgs, labels):
    start_time = time.time()
    totalLoss, globalStepCount, _ = sess.run([trainOp, globalStep, metrics_op],
                                             feed_dict={IMGS: imgs, Labels: labels, keep_prob: keep_prob_set})
    time_elapsed = time.time() - start_time

    # Run the logging to print some results
    logging.info('global step %s: loss: %.4f (%.2f sec/step)', globalStepCount, totalLoss, time_elapsed)
    # return total loss and which step
    return totalLoss, globalStepCount

if not os.path.exists(checkpointDir):
    os.mkdir(checkpointDir)

saver = tf.train.Saver(max_to_keep=200)
ckptPath = os.path.join(checkpointDir, 'mode.ckpt')
lossAvg = 0
lossTotal = 0
with tf.Session(config=config) as sess:

    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    if os.listdir(checkpointDir):
        model_file = tf.train.latest_checkpoint(checkpointDir)
        saver.restore(sess, model_file)
        print('total samples: %d' % (numBatchesPerEpoch * howManyTimeShuffleFile * howManyRepeatFileList))

    for step in range(numBatchesPerEpoch * howManyTimeShuffleFile * howManyRepeatFileList):
        imgs, labels = getImgesLabels(sess)
        # for each epoch
        if step % numBatchesPerEpoch == 0:
            logging.info('Epoch %s/%s', step // numBatchesPerEpoch + 1,
                         howManyTimeShuffleFile * howManyRepeatFileList)
            learning_rate_value, accuracy_value = sess.run([lr, accuracy], feed_dict={IMGS: imgs, Labels: labels,
                                                                                      keep_prob: keep_prob_set})
            logging.info('Current Learning Rate: %s', learning_rate_value)
            logging.info('Current Streaming Accuracy: %s', accuracy_value)
            logits_value, probabilities_value, predictions_value = sess.run([logits, probabilities, predictions],
                                                                            feed_dict={IMGS: imgs, Labels: labels,
                                                                                       keep_prob: keep_prob_set})
            print('logits: \n', logits_value[:5])
            print('Probabilities: \n', probabilities_value[:5])
            print('labels:     {}\n'.format(labels))
            print('predictions:{}\n'.format(predictions_value))
            print(lossTotal)  # total loss accumulated over the previous epoch
            lossTotal = 0

        loss, globalStepCount = train_step(sess, trainOp, globalStep, imgs, labels)
        lossAvg += loss
        lossTotal += loss

        # how many step to save a ckpt file
        if step % 150 == 0:
            learning_rate_value, accuracy_value = sess.run([lr, accuracy], feed_dict={IMGS: imgs, Labels: labels})
            print('learning_rate_value: {}\n accuracy_value: {}'.format(learning_rate_value, accuracy_value))
            with open(logDir, 'a+') as File:
                line_str = 'learining_rate: ' + str(learning_rate_value) + '  global_step: ' + str(
                    globalStepCount) + '  loss: ' + str(lossAvg / 150) + '  accuracy_value: ' + str(
                    accuracy_value) + '\n'
                print(line_str)
                File.write(line_str)
            lossAvg = 0
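            # Saver.save takes global_step (snake_case), which appends the current step number to the checkpoint filename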
            saver.save(sess, ckptPath, global_step=globalStep, write_meta_graph=False)
            if not glob.glob(os.path.join(checkpointDir, '*.meta')):  # export the meta graph only once
                saver.export_meta_graph(checkpointDir + '/mode.ckpt.meta')

```
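
With `global_step` passed to `saver.save`, the checkpoints are written to ./output_dir/models as mode.ckpt-<step>, and max_to_keep=200 keeps only the 200 most recent ones.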

dhvaraca commented 4 years ago

Thanks Sir

dhvaraca commented 4 years ago

[Screenshot (17) attached] Is this right? Up to how many steps will it run?

111xumengze commented 4 years ago

> [Screenshot (17) attached] Is this right? Up to how many steps will it run?

It runs for numBatchesPerEpoch * howManyTimeShuffleFile * howManyRepeatFileList steps, as in the training loop above.
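
As a rough back-of-the-envelope sketch (only bacthSize, howManyTimeShuffleFile and howManyRepeatFileList come from the excerpt above; numSamplesPerEpoch depends on your own TFRecords, so the value here is made up):

```python
numSamplesPerEpoch = 10000   # made-up example; depends on your dataset
bacthSize = 56
howManyTimeShuffleFile = 1
howManyRepeatFileList = 80

numBatchesPerEpoch = numSamplesPerEpoch // bacthSize                              # 178
totalSteps = numBatchesPerEpoch * howManyTimeShuffleFile * howManyRepeatFileList  # 14240
print(totalSteps)
```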

dhvaraca commented 4 years ago

What happens if I reduce the number of epochs? It is set to 80. Can I reduce it to a smaller number?