Offset bug in chain_demo.py? #16

Closed ZHAOZHIHAO closed 1 year ago

ZHAOZHIHAO commented 1 year ago

First, thanks for sharing your code.

I tried your code on my own long video, which is split into many short clips. I visualized the trajectory of a single point in TensorBoard. It tracks very well for every frame of the first clip, but the starting tracking point of the second clip is offset by about 5-10 pixels. Every frame of the second clip also seems to have the same offset from where the point should be. Is this a bug in the code, and could you please suggest a fix? Maybe it only needs a one- or two-line change?


aharley commented 1 year ago

Sounds like an issue with the padding maybe. Try eliminating the padding lines to simplify the code, and see if the issue persists: https://github.com/aharley/pips/blob/main/demo.py#L44-L46

ZHAOZHIHAO commented 1 year ago

Hi Adam,

Thanks for your super quick response. Unfortunately, that doesn't seem to work. I removed those lines and also removed "-pad" from "return trajs_e-pad" in the function "run_model". In TensorBoard, I can see that after these modifications there is no longer any black padding around the video border, but the offsets are still there.


aharley commented 1 year ago

I see. Maybe you can paste your code here, and I can take a look?

ZHAOZHIHAO commented 1 year ago

Thanks for your efforts. Here is my code. It's modified from chain_demo.py: I changed the input video file, the tracking point location, and the padding, as described in the comment above.

import time
import numpy as np
import io
import os
from PIL import Image
import cv2
import saverloader
import imageio.v2 as imageio
from nets.pips import Pips
import utils.improc
import random
import glob
from utils.basic import print_, print_stats
import torch
from tensorboardX import SummaryWriter
import torch.nn.functional as F


def run_model(model, rgbs, N, sw):
    rgbs = rgbs.cuda().float() # B, S, C, H, W

    B, S, C, H, W = rgbs.shape
    rgbs_ = rgbs.reshape(B*S, C, H, W)
    H_, W_ = 360, 640
    rgbs_ = F.interpolate(rgbs_, (H_, W_), mode='bilinear')
    H, W = H_, W_
    rgbs = rgbs_.reshape(B, S, C, H, W)

    # try to pick a point on the dog, so we get an interesting trajectory
    # x = torch.randint(-10, 10, size=(1, N), device=torch.device('cuda')) + 468
    # y = torch.randint(-10, 10, size=(1, N), device=torch.device('cuda')) + 118
    x = torch.ones((1, N), device=torch.device('cuda')) * 371.0
    y = torch.ones((1, N), device=torch.device('cuda')) * 95.0
    xy0 = torch.stack([x, y], dim=-1) # B, N, 2
    _, S, C, H, W = rgbs.shape

    trajs_e = torch.zeros((B, S, N, 2), dtype=torch.float32, device='cuda')
    for n in range(N):
        # print('working on keypoint %d/%d' % (n+1, N))
        cur_frame = 0
        done = False
        traj_e = torch.zeros((B, S, 2), dtype=torch.float32, device='cuda')
        traj_e[:,0] = xy0[:,n] # B, 1, 2  # set first position 
        feat_init = None
        while not done:
            end_frame = cur_frame + 8

            rgb_seq = rgbs[:,cur_frame:end_frame]
            S_local = rgb_seq.shape[1]
            rgb_seq = torch.cat([rgb_seq, rgb_seq[:,-1].unsqueeze(1).repeat(1,8-S_local,1,1,1)], dim=1)

            outs = model(traj_e[:,cur_frame].reshape(1, -1, 2), rgb_seq, iters=6, feat_init=feat_init, return_feat=True)
            preds = outs[0]
            vis = outs[2] # B, S, 1
            feat_init = outs[3]

            vis = torch.sigmoid(vis) # visibility confidence
            xys = preds[-1].reshape(1, 8, 2)
            traj_e[:,cur_frame:end_frame] = xys[:,:S_local]

            found_skip = False
            thr = 0.9
            si_last = 8-1 # last frame we are willing to take
            si_earliest = 1 # earliest frame we are willing to take
            si = si_last
            while not found_skip:
                if vis[0,si] > thr:
                    found_skip = True
                else:
                    # not confident enough at this frame; try an earlier one
                    si -= 1
                if si == si_earliest:
                    # print('decreasing thresh')
                    thr -= 0.02
                    si = si_last
            # print('found skip at frame %d, where we have' % si, vis[0,si].detach().item())

            cur_frame = cur_frame + si

            if cur_frame >= S:
                done = True
        trajs_e[:,:,n] = traj_e

    #pad = 50
    #rgbs = F.pad(rgbs.reshape(B*S, 3, H, W), (pad, pad, pad, pad), 'constant', 0).reshape(B, S, 3, H+pad*2, W+pad*2)
    #trajs_e = trajs_e + pad

    prep_rgbs = utils.improc.preprocess_color(rgbs)
    gray_rgbs = torch.mean(prep_rgbs, dim=2, keepdim=True).repeat(1, 1, 3, 1, 1)

    if sw is not None and sw.save_this:
        linewidth = 2

        for n in range(N):
            # print('visualizing kp %d' % n)
            kp_vis = sw.summ_traj2ds_on_rgbs('video_%d/kp_%d_trajs_e_on_rgbs' % (sw.global_step, n), trajs_e[0:1,:,n:n+1], gray_rgbs[0:1,:S], cmap='spring', linewidth=linewidth)

            # write to disk, in case that's more convenient
            kp_list = list(kp_vis.unbind(1))
            kp_list = [kp[0].permute(1,2,0).cpu().numpy() for kp in kp_list]
            kp_list = [Image.fromarray(kp) for kp in kp_list]
            out_fn = './chain_out_%d.gif' % sw.global_step
            kp_list[0].save(out_fn, save_all=True, append_images=kp_list[1:])
            print('saved %s' % out_fn)

        sw.summ_traj2ds_on_rgb('outputs/trajs_e_on_rgb', trajs_e[0:1], prep_rgbs[0:1,0], cmap='spring')
        sw.summ_traj2ds_on_rgb('outputs/trajs_e_on_rgb2', trajs_e[0:1], torch.mean(prep_rgbs[0:1], dim=1), cmap='spring')

    return trajs_e

def main():

    # the idea in this file is to chain together pips from a long sequence, and return some visualizations

    exp_name = '00' # (exp_name is used for logging notes that correspond to different runs)

    init_dir = 'reference_model'

    ## choose hyps
    B = 1
    S = 50
    N = 1 # number of points to track

    filenames = glob.glob('./woman_imgs/*.jpg')
    filenames = sorted(filenames)
    print('filenames', filenames)
    max_iters = len(filenames)//(S//2)-1 # run slightly overlapping subseqs

    log_freq = 1 # when to produce visualizations 

    ## autogen a name
    model_name = "%02d_%d_%d" % (B, S, N)
    model_name += "_%s" % exp_name
    import datetime
    model_date = datetime.datetime.now().strftime('%H:%M:%S')
    model_name = model_name + '_' + model_date
    print('model_name', model_name)

    log_dir = 'logs_chain_demo'
    writer_t = SummaryWriter(log_dir + '/' + model_name + '/t', max_queue=10, flush_secs=60)

    global_step = 0

    model = Pips(stride=4).cuda()
    parameters = list(model.parameters())
    if init_dir:
        _ = saverloader.load(init_dir, model)
    global_step = 0

    while global_step < max_iters:

        read_start_time = time.time()

        global_step += 1

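        sw_t = utils.improc.Summ_writer(
            writer=writer_t,
            global_step=global_step,
            log_freq=log_freq) # any further arguments were cut off in the paste

        try:
            rgbs = []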
            for s in range(S):
                fn = filenames[(global_step-1)*S//2+s]
                if s==0:
                    print('start frame', fn)
                im = imageio.imread(fn)
                im = im.astype(np.uint8)
                rgbs.append(torch.from_numpy(im).permute(2,0,1)) # H, W, C -> C, H, W
            rgbs = torch.stack(rgbs, dim=0).unsqueeze(0) # 1, S, C, H, W

            read_time = time.time()-read_start_time
            iter_start_time = time.time()

            with torch.no_grad():
                trajs_e = run_model(model, rgbs, N, sw_t)

            iter_time = time.time()-iter_start_time
            print('%s; step %06d/%d; rtime %.2f; itime %.2f' % (
                model_name, global_step, max_iters, read_time, iter_time))
        except FileNotFoundError as e:
            print('error', e)


if __name__ == '__main__':
    main()

aharley commented 1 year ago

Maybe I should also ask: why are you chopping your original video into clips, instead of feeding the full list of frames into the chaining? (Or have I misunderstood this?)

ZHAOZHIHAO commented 1 year ago

I converted the video into 616 frames and fed these frames into the network. In the TensorBoard visualization, the 616 frames are split into several clips, and the trajectories are shown for each clip separately, as in this figure:

[Image: Screenshot 2023-01-20 at 7 57 14 PM]

aharley commented 1 year ago

OK now I think I understand. Can you try setting S=100, to see if this tracks longer for you? Then, maybe setting S=616 will give the full output.
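
To make the effect of S concrete, here is a small sketch of the clip arithmetic in the pasted main(). It only mirrors the `max_iters` formula and the filename indexing from that script; the helper name `clip_windows` is made up for illustration and is not part of the repo.

```python
def clip_windows(num_frames, S):
    """Return (start, end) frame indices, end exclusive, for each clip.

    Mirrors `max_iters = len(filenames)//(S//2)-1` and the start index
    `(global_step-1)*S//2` from the pasted main(). Hypothetical helper.
    """
    max_iters = num_frames // (S // 2) - 1
    windows = []
    for global_step in range(1, max_iters + 1):
        start = (global_step - 1) * S // 2
        windows.append((start, start + S))
    return windows


if __name__ == '__main__':
    for S in (50, 100, 616):
        w = clip_windows(616, S)
        print('S=%d: %d clips, first %s, last %s' % (S, len(w), w[0], w[-1]))
```

With 616 frames, S=50 gives 23 half-overlapping clips (the second one starts at frame 25), S=100 gives 11, and S=616 gives a single clip covering every frame, so run_model is called only once.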

ZHAOZHIHAO commented 1 year ago

Setting S=616 works! The point is generally well tracked for the entire video, drifting by only several pixels for a few frames. Thanks for your patience with this idiot :)
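
For anyone who needs to keep a smaller S: one possible reading of the pasted script is that run_model re-initializes the point at the same fixed coordinates (371, 95) on every call, so each later clip starts from the original location rather than from where the point has moved to. The sketch below shows the idea of carrying the position across clips instead. It is an assumption-laden toy, not the repo's method: `track_clip` is a hypothetical stand-in for a run_model that accepts a start point, and `toy_tracker` just moves the point 1 px per frame.

```python
def chain_clips(track_clip, num_frames, S, xy0):
    """Track one point over half-overlapping clips, seeding each clip
    with the previous clip's estimate at the new clip's first frame.

    track_clip(start, end, xy) -> list of (x, y), one per frame, is a
    hypothetical stand-in for a run_model that takes a start point.
    """
    max_iters = num_frames // (S // 2) - 1  # same clip count as the pasted main()
    full = {}
    xy = xy0
    for step in range(max_iters):
        start = step * S // 2
        traj = track_clip(start, start + S, xy)
        for i, p in enumerate(traj):
            full[start + i] = p  # later clips overwrite the overlapping frames
        xy = traj[S // 2]  # position at the frame where the next clip begins
    return [full[t] for t in sorted(full)]


def toy_tracker(start, end, xy):
    # Toy motion model (assumption): the point drifts 1 px right per frame.
    return [(xy[0] + i, xy[1]) for i in range(end - start)]


if __name__ == '__main__':
    traj = chain_clips(toy_tracker, 616, 50, (371.0, 95.0))
    print(len(traj), traj[0], traj[25], traj[-1])
```

In this toy setup the chained trajectory stays continuous at the clip boundary (frame 25 continues from frame 24), whereas restarting every clip at (371, 95) would put the second clip 25 px away from the moving point.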