tensorlayer / SRGAN

Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
https://github.com/tensorlayer/tensorlayerx
3.31k stars 811 forks source link

resources exhausted #139

Open mrgreen3325 opened 5 years ago

mrgreen3325 commented 5 years ago

The error is "resource exhausted", even though I tried not to pre-load the whole training set. I am using a GTX 1070.

def train():
    """Train the SRGAN generator and discriminator.

    The function runs two phases over the HR images found under
    ``config.TRAIN.hr_img_path``:

    1. Generator pre-training with the MSE loss only, for epochs
       ``0..n_epoch_init``.
    2. Adversarial training (MSE + VGG feature + adversarial loss for the
       generator, sigmoid cross-entropy for the discriminator), for epochs
       ``0..n_epoch``, with a step-wise learning-rate decay.

    Images are read from disk one batch at a time instead of pre-loading the
    whole train set. Every 10 epochs (except epoch 0) a grid of generated
    samples is written under ``samples/`` and the weights are saved under
    ``checkpoint/``.
    """
    ## create folders to save result images and trained model
    save_dir_ginit = "samples/{}_ginit".format(tl.global_flag['mode'])
    save_dir_gan = "samples/{}_gan".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir_ginit)
    tl.files.exists_or_mkdir(save_dir_gan)
    checkpoint_dir = "checkpoint"  # checkpoint_resize_conv
    tl.files.exists_or_mkdir(checkpoint_dir)

    ###====================== PRE-LOAD DATA ===========================###
    ## Only the file lists are loaded here; pixel data is read per batch.
    train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
    train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    ###========================== DEFINE MODEL ============================###
    ## train inference
    t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
    t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')

    net_g = SRGAN_g(t_image, is_train=True, reuse=False)
    net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
    _, logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

    net_g.print_params(False)
    net_g.print_layers()
    net_d.print_params(False)
    net_d.print_layers()

    ## vgg inference. 0, 1, 2, 3 BILINEAR NEAREST BICUBIC AREA
    t_target_image_224 = tf.image.resize_images(
        t_target_image, size=[224, 224], method=0,
        align_corners=False)  # resize_target_image_for_vgg # http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
    t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False)  # resize_generate_image_for_vgg

    net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224 + 1) / 2, reuse=False)
    _, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224 + 1) / 2, reuse=True)

    ## test inference
    net_g_test = SRGAN_g(t_image, is_train=False, reuse=True)

    ###========================== DEFINE TRAIN OPS ==========================###
    d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
    d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
    d_loss = d_loss1 + d_loss2

    g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
    mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
    vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

    g_loss = mse_loss + vgg_loss + g_gan_loss

    g_vars = tl.layers.get_variables_with_name('SRGAN_g', True, True)
    d_vars = tl.layers.get_variables_with_name('SRGAN_d', True, True)

    with tf.variable_scope('learning_rate'):
        lr_v = tf.Variable(lr_init, trainable=False)
    ## Pretrain
    g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
    ## SRGAN
    g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
    d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)

    ###========================== RESTORE MODEL =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    # Prefer the GAN-stage generator weights; fall back to the pre-trained ones.
    if tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']), network=net_g) is False:
        tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), network=net_g)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']), network=net_d)

    ###============================= LOAD VGG ===============================###
    vgg19_npy_path = "vgg19.npy"
    if not os.path.isfile(vgg19_npy_path):
        print("Please download vgg19.npy from : https://github.com/machrisaa/tensorflow-vgg")
        exit()
    # NOTE: the file holds a pickled dict, which newer NumPy releases refuse to
    # load unless allow_pickle=True. Unpickling can execute arbitrary code, so
    # only use a vgg19.npy obtained from a source you trust.
    npz = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()

    params = []
    for val in sorted(npz.items()):
        W = np.asarray(val[1][0])
        b = np.asarray(val[1][1])
        print("  Loading %s: %s, %s" % (val[0], W.shape, b.shape))
        params.extend([W, b])
    tl.files.assign_params(sess, params, net_vgg)

    ###============================= TRAINING ===============================###
    ## use first `batch_size` of train set to have a quick test during training
    # The train set is not pre-loaded, so read just these files from disk.
    sample_imgs = tl.vis.read_images(train_hr_img_list[0:batch_size], path=config.TRAIN.hr_img_path, n_threads=32)
    sample_imgs_384 = tl.prepro.threading_data(sample_imgs, fn=crop_sub_imgs_fn, is_random=False)
    print('sample HR sub-image:', sample_imgs_384.shape, sample_imgs_384.min(), sample_imgs_384.max())
    sample_imgs_96 = tl.prepro.threading_data(sample_imgs_384, fn=downsample_fn)
    print('sample LR sub-image:', sample_imgs_96.shape, sample_imgs_96.min(), sample_imgs_96.max())
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_ginit + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_ginit + '/_train_sample_384.png')
    tl.vis.save_images(sample_imgs_96, [ni, ni], save_dir_gan + '/_train_sample_96.png')
    tl.vis.save_images(sample_imgs_384, [ni, ni], save_dir_gan + '/_train_sample_384.png')

    # The placeholders have a fixed leading dimension of `batch_size`, so a
    # shorter trailing batch cannot be fed; iterate over full batches only.
    n_full_batch_end = len(train_hr_img_list) - batch_size + 1

    ###========================= initialize G ====================###
    ## fixed learning rate
    sess.run(tf.assign(lr_v, lr_init))
    print(" ** fixed learning rate: %f (for init G)" % lr_init)
    for epoch in range(0, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0

        ## load a batch of images from disk at every step
        random.shuffle(train_hr_img_list)
        for idx in range(0, n_full_batch_end, batch_size):
            step_time = time.time()
            b_imgs_384, b_imgs_96 = _load_train_batch(train_hr_img_list[idx:idx + batch_size])
            ## update G
            errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " % (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        # max(..., 1) avoids a ZeroDivisionError when no full batch was available.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / max(n_iter, 1))
        print(log)

        ## quick evaluation on train set, then save model
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_ginit + '/train_%d.png' % epoch)
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}_init.npz'.format(tl.global_flag['mode']), sess=sess)

    ###========================= train GAN (SRGAN) =========================###
    for epoch in range(0, n_epoch + 1):
        ## update learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            sess.run(tf.assign(lr_v, lr_init * new_lr_decay))
            log = " ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay)
            print(log)
        elif epoch == 0:
            sess.run(tf.assign(lr_v, lr_init))
            log = " ** init lr: %f  decay_every_init: %d, lr_decay: %f (for GAN)" % (lr_init, decay_every, lr_decay)
            print(log)

        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0

        ## load a batch of images from disk at every step
        random.shuffle(train_hr_img_list)
        for idx in range(0, n_full_batch_end, batch_size):
            step_time = time.time()
            b_imgs_384, b_imgs_96 = _load_train_batch(train_hr_img_list[idx:idx + batch_size])
            ## update D
            errD, _ = sess.run([d_loss, d_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            ## update G
            errG, errM, errV, errA, _ = sess.run([g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)" %
                  (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1

        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (epoch, n_epoch, time.time() - epoch_time, total_d_loss / max(n_iter, 1),
                                                                                total_g_loss / max(n_iter, 1))
        print(log)

        ## quick evaluation on train set, then save model
        if (epoch != 0) and (epoch % 10 == 0):
            out = sess.run(net_g_test.outputs, {t_image: sample_imgs_96})
            print("[*] save images")
            tl.vis.save_images(out, [ni, ni], save_dir_gan + '/train_%d.png' % epoch)
            tl.files.save_npz(net_g.all_params, name=checkpoint_dir + '/g_{}.npz'.format(tl.global_flag['mode']), sess=sess)
            tl.files.save_npz(net_d.all_params, name=checkpoint_dir + '/d_{}.npz'.format(tl.global_flag['mode']), sess=sess)


def _load_train_batch(b_imgs_list):
    """Read one batch of HR training files and build the matching LR inputs.

    Args:
        b_imgs_list: file names relative to ``config.TRAIN.hr_img_path``.

    Returns:
        A ``(b_imgs_384, b_imgs_96)`` tuple: the randomly cropped HR targets and
        their downsampled counterparts, fed to ``t_target_image`` and
        ``t_image`` respectively.
    """
    b_imgs = tl.prepro.threading_data(b_imgs_list, fn=get_imgs_fn, path=config.TRAIN.hr_img_path)
    b_imgs_384 = tl.prepro.threading_data(b_imgs, fn=crop_sub_imgs_fn, is_random=True)
    b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
    return b_imgs_384, b_imgs_96

def evaluate():
    """Super-resolve one validation image with the trained generator.

    Loads the validation LR/HR images, restores the generator weights from
    ``checkpoint/g_srgan.npz``, runs the generator on the LR image selected by
    ``imid`` and writes four PNGs under ``samples/<mode>/``: the generated
    image, the LR input, the HR ground truth and a bicubic upscale of the LR
    input for comparison.
    """
    ## create folders to save result images
    save_dir = "samples/{}".format(tl.global_flag['mode'])
    tl.files.exists_or_mkdir(save_dir)
    checkpoint_dir = "checkpoint"

    ###====================== PRE-LOAD DATA ===========================###
    # Only the validation set is needed here; the train set is not loaded.
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)

    ###========================== DEFINE MODEL ============================###
    imid = 64  # 0: penguin  81: butterfly  53: bird  64: castle
    valid_lr_img = valid_lr_imgs[imid]
    valid_hr_img = valid_hr_imgs[imid]
    # valid_lr_img = get_imgs_fn('test.png', 'data2017/')  # if you want to test your own image
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]

    size = valid_lr_img.shape
    # Height and width are left as None so an LR image of any size can be fed.
    t_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')

    net_g = SRGAN_g(t_image, is_train=False, reuse=False)

    ###========================== RESTORE G =============================###
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False))
    tl.layers.initialize_global_variables(sess)
    tl.files.load_and_assign_npz(sess=sess, name=checkpoint_dir + '/g_srgan.npz', network=net_g)

    ###======================= EVALUATION =============================###
    start_time = time.time()
    out = sess.run(net_g.outputs, {t_image: [valid_lr_img]})
    print("took: %4.4fs" % (time.time() - start_time))

    print("LR size: %s /  generated HR size: %s" % (size, out.shape))  # LR size: (339, 510, 3) /  gen HR size: (1, 1356, 2040, 3)
    print("[*] save images")
    tl.vis.save_image(out[0], save_dir + '/valid_gen.png')
    tl.vis.save_image(valid_lr_img, save_dir + '/valid_lr.png')
    tl.vis.save_image(valid_hr_img, save_dir + '/valid_hr.png')

    # NOTE(review): scipy.misc.imresize was removed in SciPy 1.3.0, so this
    # call requires an older SciPy release.
    out_bicu = scipy.misc.imresize(valid_lr_img, [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
    tl.vis.save_image(out_bicu, save_dir + '/valid_bicubic.png')

# Script entry point: pick training or evaluation from the --mode flag.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='srgan', help='srgan, evaluate')
    args = parser.parse_args()

    # train() and evaluate() read the mode back from tl.global_flag to name
    # their output folders and checkpoint files.
    tl.global_flag['mode'] = args.mode

    if tl.global_flag['mode'] == 'srgan':
        train()
    elif tl.global_flag['mode'] == 'evaluate':
        evaluate()
    else:
        raise Exception("Unknow --mode")
zsdonghao commented 5 years ago

I see. This could help: load the images for each batch using the method in the following link, instead of reading the entire dataset into RAM at the beginning.

https://github.com/tensorlayer/dcgan/blob/master/data.py#L28

Also reduce the prefetch size to 2 and remove the shuffle line.