ogroth / tf-gqn

Tensorflow implementation of Neural Scene Representation and Rendering
Apache License 2.0

How can I get the right result? #79

Open 11lucky111 opened 1 year ago

11lucky111 commented 1 year ago

I used the view interpolation notebook to load shepard_metzler_5_parts, but I can't get a correct result. This is my result: view_interpolation_preview. Here is my process:

'''imports'''
# stdlib
import os
import sys
import logging
# numerical computing
import numpy as np
import tensorflow as tf
# plotting
import imageio
logging.getLogger("imageio").setLevel(logging.ERROR)  # switch off warnings during lossy GIF-generation
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from IPython.display import Image, display
# GQN src
root_path = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(root_path)
print(sys.path)
from data_provider.gqn_provider import gqn_input_fn
from gqn.gqn_predictor import GqnViewPredictor

['C:\Users\lenovo\AppData\Local\conda\conda\envs\tensorflownew\python36.zip', 'C:\Users\lenovo\AppData\Local\conda\conda\envs\tensorflownew\DLLs', 'C:\Users\lenovo\AppData\Local\conda\conda\envs\tensorflownew\lib', 'C:\Users\lenovo\AppData\Local\conda\conda\envs\tensorflownew', '', 'C:\Users\lenovo\AppData\Local\conda\conda\envs\tensorflownew\lib\site-packages', 'C:\Users\lenovo\AppData\Local\conda\conda\envs\tensorflownew\lib\site-packages\win32', 'C:\Users\lenovo\AppData\Local\conda\conda\envs\tensorflownew\lib\site-packages\win32\lib', 'C:\Users\lenovo\AppData\Local\conda\conda\envs\tensorflownew\lib\site-packages\Pythonwin', 'C:\Users\lenovo\AppData\Local\conda\conda\envs\tensorflownew\lib\site-packages\IPython\extensions', 'C:\Users\lenovo\.ipython', 'E:\Desktop\tf-gqn-master\tf-gqn-master']

'''directory setup'''
data_dir = os.path.join(root_path, 'data')
model_dir = os.path.join(root_path, 'models')
tmp_dir = os.path.join(root_path, 'notebooks', 'tmp')
gqn_dataset_path = os.path.join(data_dir, 'gqn-dataset')
# dataset flags
# dataset_name = 'jaco'  # one of the GQN dataset names
# dataset_name = 'rooms_ring_camera'  # one of the GQN dataset names
# dataset_name = 'rooms_free_camera_no_object_rotations'  # one of the GQN dataset names
# dataset_name = 'rooms_free_camera_with_object_rotations'  # one of the GQN dataset names
dataset_name = 'shepard_metzler_5_parts'  # one of the GQN dataset names
# dataset_name = 'shepard_metzler_7_parts'  # one of the GQN dataset names
data_path = os.path.join(gqn_dataset_path, dataset_name)
print("Data path: %s" % (data_path, ))
# model flags
model_name = 'gqn'
# model_name = 'gqn8'
# model_name = 'gqn12'
gqn_model_path = os.path.join(model_dir, dataset_name)
model_path = os.path.join(gqn_model_path, model_name)
print("Model path: %s" % (model_path, ))
# tmp
notebook_name = 'view_interpolation'
notebook_tmp_path = os.path.join(tmp_dir, notebook_name)
os.makedirs(notebook_tmp_path, exist_ok=True)
print("Tmp path: %s" % (notebook_tmp_path, ))

Data path: E:\Desktop\tf-gqn-master\tf-gqn-master\data\gqn-dataset\shepard_metzler_5_parts
Model path: E:\Desktop\tf-gqn-master\tf-gqn-master\models\shepard_metzler_5_parts\gqn
Tmp path: E:\Desktop\tf-gqn-master\tf-gqn-master\notebooks\tmp\view_interpolation

'''data reader setup'''
mode = tf.estimator.ModeKeys.EVAL
ctx_size = 5  # needs to be the same as the context size defined in gqn_config.json in the model_path
batch_size = 1  # should be kept at 1
dataset = gqn_input_fn(
        dataset_name=dataset_name, root=gqn_dataset_path, mode=mode,
        context_size=ctx_size, batch_size=batch_size, num_epochs=1,
        num_threads=4, buffer_size=1)
iterator = dataset.make_initializable_iterator()
data = iterator.get_next()
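# optional sanity check (a sketch, not from the original notebook): ctx_size has
# to match the context size the model was trained with; the 'context_size' key
# name below is an assumption, so check your gqn_config.json for the actual field
import json
with open(os.path.join(model_path, 'gqn_config.json')) as f:
    gqn_config = json.load(f)
print("Trained context size: %s" % (gqn_config.get('context_size'), ))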
'''video predictor & session setup'''
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # run on GPU 0; set to '' to run on CPU only
#os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
predictor = GqnViewPredictor(model_path)
sess = predictor.sess
sess.run(iterator.initializer)
print("Loop completed.")

>>> Instantiated GQN:
enc_r     Tensor("GQN/Sum:0", shape=(1, 1, 1, 256), dtype=float32)
canvas_0  Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_1  Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_1:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_2  Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_2:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_3  Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_3:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_4  Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_4:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_5  Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_5:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_6  Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_6:0", shape=(1, 64, 64, 256), dtype=float32)
canvas_7  Tensor("GQN/GQN_RNN/Generator/LSTM_gen/add_7:0", shape=(1, 64, 64, 256), dtype=float32)
mu_target Tensor("GQN/eta_g/BiasAdd:0", shape=(1, 64, 64, 3), dtype=float32)
INFO:tensorflow:Restoring parameters from E:\Desktop\tf-gqn-master\tf-gqn-master\models\shepard_metzler_5_parts\gqn\model.ckpt-0
>>> Restored parameters from: E:\Desktop\tf-gqn-master\tf-gqn-master\models\shepard_metzler_5_parts\gqn\model.ckpt-0
Loop completed.
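One detail worth noting in the restore log above: the parameters come from model.ckpt-0, i.e. a checkpoint written at global step 0, which is usually saved right when training starts. Assuming the predictor restores the latest checkpoint in model_path, this one-liner shows which file gets picked up:

print(tf.train.latest_checkpoint(model_path))  # prints e.g. ...\model.ckpt-0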

'''data visualization'''
skip_load = 1  # adjust this to skip through records
print("Loop completed.")
# fetch & parse
for _ in range(skip_load):
    d, _ = sess.run(data)
ctx_frames = d.query.context.frames
ctx_poses = d.query.context.cameras
tgt_frame = d.target
tgt_pose = d.query.query_camera
tuple_length = ctx_size + 1  # context points + 1 target

print(">>> Context frames:\t%s" % (ctx_frames.shape, ))
print(">>> Context poses: \t%s" % (ctx_poses.shape, ))
print(">>> Target frame:  \t%s" % (tgt_frame.shape, ))
print(">>> Target pose:   \t%s" % (tgt_pose.shape, ))

# visualization constants
MAX_COLS_PER_ROW = 6
TILE_HEIGHT, TILE_WIDTH, TILE_PAD = 3.0, 3.0, 0.8
np.set_printoptions(precision=2, suppress=True)

# visualize all data tuples in the batch
for n in range(batch_size):
    # define image grid
    ncols = int(np.min([tuple_length, MAX_COLS_PER_ROW]))
    nrows = int(np.ceil(tuple_length / MAX_COLS_PER_ROW))
    fig = plt.figure(figsize=(TILE_WIDTH * ncols, TILE_HEIGHT * nrows))
    grid = ImageGrid(
        fig, 111,  # similar to subplot(111)
        nrows_ncols=(nrows, ncols),
        axes_pad=TILE_PAD,  # pad between axes in inch.
    )
    # visualize context
    for ctx_idx in range(ctx_size):
        rgb = ctx_frames[n, ctx_idx]
        pose = ctx_poses[n, ctx_idx]
        caption = "ctx: %02d\nxyz:%s\nyp:%s" % \
            (ctx_idx + 1, pose[0:3], pose[3:])
        grid[ctx_idx].imshow(rgb)
        grid[ctx_idx].set_title(caption, loc='center')
    # visualize target
    rgb = tgt_frame[n]
    pose = tgt_pose[n]
    caption = "target\nxyz:%s\nyp:%s" % \
        (pose[0:3], pose[3:])
    grid[-1].imshow(rgb)
    grid[-1].set_title(caption, loc='center')
    plt.show()

Loop completed.
>>> Context frames:  (1, 5, 64, 64, 3)
>>> Context poses:   (1, 5, 7)
>>> Target frame:    (1, 64, 64, 3)
>>> Target pose:     (1, 7)
[image: context views and target view]

'''run the view prediction'''

# visualize all data tuples in the batch
for n in range(batch_size):

    print(">>> Predictions:")
    # define image grid for predictions
    ncols = int(np.min([tuple_length, MAX_COLS_PER_ROW]))
    nrows = int(np.ceil(tuple_length / MAX_COLS_PER_ROW))
    fig = plt.figure(figsize=(TILE_WIDTH * ncols, TILE_HEIGHT * nrows))
    grid = ImageGrid(
        fig, 111,  # similar to subplot(111)
        nrows_ncols=(nrows, ncols),
        axes_pad=TILE_PAD,  # pad between axes in inch.
    )
    # load the scene context
    predictor.clear_context()
    for i in range(ctx_size):
        ctx_frame = ctx_frames[n, i]
        ctx_pose = ctx_poses[n, i]
        predictor.add_context_view(ctx_frame, ctx_pose)
    # render query
    query_pose = tgt_pose[n]
    pred_frame = predictor.render_query_view(query_pose)[0]
    caption = "query\nxyz:%s\nyp:%s" % \
        (query_pose[0:3], query_pose[3:])
    grid[0].imshow(pred_frame)
    grid[0].set_title(caption, loc='center')
    # re-render context (auto-encoding consistency)
    for ctx_idx in range(ctx_size):
        query_pose = ctx_poses[n, ctx_idx]
        pred_frame = predictor.render_query_view(query_pose)[0]
        caption = "ctx: %02d\nxyz:%s\nyp:%s" % \
            (ctx_idx + 1, query_pose[0:3], query_pose[3:])
        grid[ctx_idx + 1].imshow(pred_frame)
        grid[ctx_idx + 1].set_title(caption, loc='center')
    plt.show()

    print(">>> Ground truth:")
    # define image grid for predictions
    ncols = int(np.min([tuple_length, MAX_COLS_PER_ROW]))
    nrows = int(np.ceil(tuple_length / MAX_COLS_PER_ROW))
    fig = plt.figure(figsize=(TILE_WIDTH * ncols, TILE_HEIGHT * nrows))
    grid = ImageGrid(
        fig, 111,  # similar to subplot(111)
        nrows_ncols=(nrows, ncols),
        axes_pad=TILE_PAD,  # pad between axes in inch.
    )
    # query
    pose = tgt_pose[n]
    rgb = tgt_frame[n]
    caption = "query\nxyz:%s\nyp:%s" % \
        (pose[0:3], pose[3:])
    grid[0].imshow(rgb)
    grid[0].set_title(caption, loc='center')
    # context
    for ctx_idx in range(ctx_size):
        pose = ctx_poses[n, ctx_idx]
        rgb = ctx_frames[n, ctx_idx]
        caption = "ctx: %02d\nxyz:%s\nyp:%s" % \
            (ctx_idx + 1, pose[0:3], pose[3:])
        grid[ctx_idx + 1].imshow(rgb)
        grid[ctx_idx + 1].set_title(caption, loc='center')
    plt.show()

[image: predicted views vs. ground truth]

'''render a view interpolation trajectory'''

# query pose trajectory per dataset
# [[0, 0, 0, yaw, 0] for yaw in range(0, 360, 10)]
def _query_poses(num_poses=40, radius=3.0, height=2.5, angle=30.0):
    x = list(radius * np.sin(np.linspace(np.pi, -np.pi, num_poses)))
    y = list(radius * np.cos(np.linspace(np.pi, -np.pi, num_poses)))
    z = list(height * np.ones((num_poses, )))
    yaw = list(np.linspace(0.0, 360.0, num_poses))
    pitch = list(angle * np.ones((num_poses, )))
    poses = list(zip(x, y, z, yaw, pitch))
    return poses
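# NOTE: _query_poses yields 5-dim tuples (x, y, z, yaw, pitch) with angles in
# degrees, while the context poses printed above are 7-dim. If render_query_view
# expects the dataset's 7-dim layout (an assumption; check gqn/gqn_predictor.py),
# a hypothetical conversion helper could look like this:
def _to_7dim_pose(pose_5dim):
    x, y, z, yaw, pitch = pose_5dim
    yaw, pitch = np.deg2rad(yaw), np.deg2rad(pitch)  # assuming degrees in
    return np.array([x, y, z,
                     np.sin(yaw), np.cos(yaw),
                     np.sin(pitch), np.cos(pitch)])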

QUERY_POSES = {
    'shepard_metzler_5_parts' : _query_poses(),
    'shepard_metzler_7_parts' : _query_poses(),
}

# generate query poses
query_poses = QUERY_POSES[dataset_name]
query_poses = [np.array(qp) for qp in query_poses]

# render corresponding views
print(">>> Rendering interpolation trajectory for %d query poses..." % (len(query_poses), ))
frame_buffer = []
for i, query_pose in enumerate(query_poses):
    pred_frame = predictor.render_query_view(query_pose)[0]
    frame_buffer.append(pred_frame)
    if (i+1) % 10 == 0:
        print("    %d / %d frames rendered." % ((i+1), len(query_poses)))

# show gif of view interpolation trajectory
gif_tmp_path = os.path.join(notebook_tmp_path, 'view_interpolation_preview.gif')
imageio.mimsave(gif_tmp_path, frame_buffer)
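# optional: imageio's GIF writer accepts a per-frame 'duration' in seconds,
# e.g. imageio.mimsave(gif_tmp_path, frame_buffer, duration=0.1) for ~10 fps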
with open(gif_tmp_path, 'rb') as file:
    display(Image(file.read()))

>>> Rendering interpolation trajectory for 40 query poses...
    10 / 40 frames rendered.
    20 / 40 frames rendered.
    30 / 40 frames rendered.
    40 / 40 frames rendered.
[GIF: view_interpolation_preview]