Open — mrgreen3325 opened this issue 5 years ago.
This approach could help: use the method in the following link to load the images for each batch instead of reading the entire dataset into RAM at the start.
https://github.com/tensorlayer/dcgan/blob/master/data.py#L28
Also reduce the prefetch size to 2 and remove the shuffle line.
The error is "resource exhausted", even though I tried not to preload the whole training set. I am using a GTX 1070.
def _make_training_batch(hr_imgs, start, batch_size):
    """Build one (generator input, target) batch from ``hr_imgs``.

    Takes ``hr_imgs[start:start + batch_size]`` and crops each image with
    ``crop_sub_imgs_fn`` (``is_random=True``). The crops are then passed
    through ``downsample_fn`` to produce the generator inputs.

    Returns:
        ``(b_imgs_96, b_imgs_384)``: the downsampled batch and the cropped
        target batch. ``train`` feeds them to ``t_image`` and
        ``t_target_image`` respectively.
    """
    b_imgs_384 = tl.prepro.threading_data(
        hr_imgs[start:start + batch_size], fn=crop_sub_imgs_fn, is_random=True)
    b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
    return b_imgs_96, b_imgs_384


def train():
    """Pre-train G with the MSE loss, then train D and G adversarially.

    Phase 1 runs ``n_epoch_init`` epochs that step ``g_optim_init`` and
    report ``mse_loss``. Phase 2 runs ``n_epoch`` epochs that step
    ``d_optim`` and then ``g_optim`` on every batch. Both phases print
    per-step losses and a per-epoch average.

    NOTE(review): this excerpt omits the setup before the loops: creating
    the output folders, loading ``train_hr_imgs``, building the graph and
    ``sess``, and defining ``batch_size``, ``n_epoch_init`` and ``n_epoch``.
    Those names are assumed to be in scope. The two epoch loops are
    reconstructed from the ``epoch``/``epoch_time`` names used by the log
    lines; confirm them against the full script.

    NOTE(review): if ``train_hr_imgs`` holds every decoded training image,
    host RAM grows with the dataset. To load per batch, slice the list of
    training file names instead and decode each slice, e.g. with
    ``tl.vis.read_images`` as ``evaluate`` does for the validation images.
    A TensorFlow "resource exhausted" error is often a device (GPU)
    allocation failure. A smaller ``batch_size`` addresses that;
    per-batch loading alone does not.
    """
    # create folders to save result images and trained model
    # (elided in this excerpt)

    # Phase 1: initialize G using the MSE loss only.
    for epoch in range(1, n_epoch_init + 1):
        epoch_time = time.time()
        total_mse_loss, n_iter = 0, 0
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_96, b_imgs_384 = _make_training_batch(train_hr_imgs, idx, batch_size)
            # update G
            errM, _ = sess.run([mse_loss, g_optim_init],
                               {t_image: b_imgs_96, t_target_image: b_imgs_384})
            print("Epoch [%2d/%2d] %4d time: %4.4fs, mse: %.8f " %
                  (epoch, n_epoch_init, n_iter, time.time() - step_time, errM))
            total_mse_loss += errM
            n_iter += 1
        # max(n_iter, 1): with no batches the average would divide by zero.
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, mse: %.8f" % (
            epoch, n_epoch_init, time.time() - epoch_time, total_mse_loss / max(n_iter, 1))
        print(log)

    # Phase 2: alternate D and G updates on every batch.
    for epoch in range(1, n_epoch + 1):
        epoch_time = time.time()
        total_d_loss, total_g_loss, n_iter = 0, 0, 0
        for idx in range(0, len(train_hr_imgs), batch_size):
            step_time = time.time()
            b_imgs_96, b_imgs_384 = _make_training_batch(train_hr_imgs, idx, batch_size)
            feed = {t_image: b_imgs_96, t_target_image: b_imgs_384}
            # update D
            errD, _ = sess.run([d_loss, d_optim], feed)
            # update G
            errG, errM, errV, errA, _ = sess.run(
                [g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], feed)
            print("Epoch [%2d/%2d] %4d time: %4.4fs, d_loss: %.8f g_loss: %.8f (mse: %.6f vgg: %.6f adv: %.6f)" %
                  (epoch, n_epoch, n_iter, time.time() - step_time, errD, errG, errM, errV, errA))
            total_d_loss += errD
            total_g_loss += errG
            n_iter += 1
        # Guard the epoch averages against an empty batch loop.
        n_batches = max(n_iter, 1)
        log = "[*] Epoch: [%2d/%2d] time: %4.4fs, d_loss: %.8f g_loss: %.8f" % (
            epoch, n_epoch, time.time() - epoch_time, total_d_loss / n_batches, total_g_loss / n_batches)
        print(log)
def evaluate():
    """Read the validation LR and HR images and print each image's shape.

    Also prints the shape of every image in ``train_hr_imgs``.

    NOTE(review): this excerpt omits the folder creation and the code that
    defines ``train_hr_imgs``, ``valid_lr_img_list``, ``valid_hr_img_list``
    and ``config``, plus whatever evaluation follows. Those names are
    assumed to be in scope. ``tl.vis.read_images`` returns the whole list
    here, so every validation image is held in host memory at once.
    """
    # create folders to save result images
    # (elided in this excerpt)
    for im in train_hr_imgs:
        print(im.shape)

    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
    for im in valid_lr_imgs:
        print(im.shape)

    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)
    for im in valid_hr_imgs:
        print(im.shape)
if name == 'main': import argparse parser = argparse.ArgumentParser()