facebookresearch / vggsfm

VGGSfM: Visual Geometry Grounded Deep Structure From Motion

Questions about training the tracker T #21

Open qsisi opened 2 months ago

qsisi commented 2 months ago

Hello! @jytime I have several questions about the multi-stage training, specifically, the tracker.

In my understanding, you trained the tracker on Kubric first, then finetuned it on Co3D or MegaDepth depending on the test dataset. CoTracker trained its model on Kubric with sliding_window=8 and a total of 24 frames per sequence. The VGGSfM paper says:

> differently from [27, 39], our tracker does not assume temporal continuity. Therefore, we avoid the sliding window approach and, instead, attend to all the input frames together.

So did VGGSfM train its tracker with sliding_window=24 on Kubric (which might cause huge GPU memory consumption)? Also, what about MegaDepth or Co3D?

Looking forward to the release of the training code :)

jytime commented 2 months ago

Hi @qsisi ,

I am not sure if I understand the question correctly. For all the datasets, we train the tracker with a frame number in range(3, 31). That is, we randomly sample 3-31 frames for a scene and feed them directly to the tracker. This fits well on an 80 GB A100 GPU (with bf16) with 512 tracks.
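For reference, a minimal sketch of the kind of per-scene frame sampling described above; the function and variable names are hypothetical and not taken from the released code:

    import random

    def sample_scene_frames(scene_frame_count, min_frames=3, max_frames=31):
        # Randomly pick how many frames this training sample uses; the tracker
        # attends to all of them jointly, so no sliding window is involved.
        num_frames = random.randint(min_frames, min(max_frames, scene_frame_count))
        # Unordered sampling is fine, since no temporal continuity is assumed.
        return random.sample(range(scene_frame_count), num_frames)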

qsisi commented 2 months ago

Thanks for your prompt reply. So for every training batch, the video length is a random number in [3, 31], for Kubric/MegaDepth/Co3D, right?

Also, in my trials, the tracker performance differs between your previously released model vggsfm_v102.bin and the current model vggsfm_v2_0_0.bin, and it seems like the previous model is better than the current one :) May I ask what happened here? Were these two models' tracker parts trained with different strategies?

jytime commented 2 months ago

It is the same for kubric/MegaDepth/Co3D. However, it should be noted that in the original kubric dataset, the videos may only have 24 frames, so it should be [3,24].

Can you elaborate on how you evaluated the performance of the two trackers? In our internal experiments, the new one should be more accurate and generalize better.

qsisi commented 2 months ago

Thanks for your reply.

> It is the same for kubric/MegaDepth/Co3D. However, it should be noted that in the original kubric dataset, the videos may only have 24 frames, so it should be [3,24].

  1. So the tracker is trained with a random-length video input: [3, 24] for Kubric and [3, 31] for MegaDepth, right?
  2. In my understanding, the tracker of VGGSfM is analogous to CoTracker, so is the above training strategy better than CoTracker's "unrolled window training"?

> Can you elaborate on how you evaluated the performance of the two trackers? In our internal experiments, the new one should be more accurate and generalize better.

  1. I tested the coarse tracker on random images from this data: https://cvg-data.inf.ethz.ch/local-feature-evaluation-schoenberger2017/South-Building.zip, and counted the track_length of the predicted tracks. Here is the code:
    
    from vggsfm.models.track_modules.base_track_predictor import BaseTrackerPredictor
    from vggsfm.models.track_modules.blocks import BasicEncoder
    from gluefactory.models.extractors.superpoint_open import SuperPoint
    from gluefactory.models.extractors.sift import SIFT
    import torch.nn.functional as F
    from omegaconf import DictConfig, OmegaConf
    import hydra
    import torch
    import numpy as np
    import cv2 as cv
    from collections import Counter
    import matplotlib.pyplot as plt
    import copy


    def get_query_points(superpoint, sift, query_image, max_query_num=4096):
        pred_sp = superpoint({"image": query_image})["keypoints"]
        pred_sift = sift({"image": query_image})["keypoints"]

        query_points = torch.cat([pred_sp, pred_sift], dim=1)
        query_points = pred_sift

        if query_points.shape[1] > max_query_num:
            random_point_indices = torch.randperm(query_points.shape[1])[:max_query_num]
            query_points = query_points[:, random_point_indices, :]

        return query_points


    @hydra.main(config_path="cfgs/", config_name="demo")
    def main(cfg: DictConfig):
        OmegaConf.set_struct(cfg, False)
        fnet_v1 = BasicEncoder(cfg=cfg).cuda()
        tracker_v1 = BaseTrackerPredictor(cfg=cfg).cuda()
        fnet_v2 = BasicEncoder(cfg=cfg).cuda()
        tracker_v2 = BaseTrackerPredictor(cfg=cfg).cuda()
        ckpt_v1 = torch.load("/data/vggsfm/vggsfm_v102.bin")
        ckpt_v2 = torch.load("/data/vggsfm_v2/vggsfm/vggsfm_v2_0_0.bin")

        # Extract the coarse feature network and coarse tracker weights from each checkpoint.
        fnet_dict_from_ckpt_v1 = {key.replace("track_predictor.coarse_fnet.", ""): val for key, val in ckpt_v1.items() if "track_predictor.coarse_fnet" in key}
        assert fnet_dict_from_ckpt_v1.keys() == fnet_v1.state_dict().keys()
        fnet_v1.load_state_dict(fnet_dict_from_ckpt_v1, strict=True)

        tracker_dict_from_ckpt_v1 = {key.replace("track_predictor.coarse_predictor.", ""): val for key, val in ckpt_v1.items() if "track_predictor.coarse_predictor" in key}
        assert tracker_dict_from_ckpt_v1.keys() == tracker_v1.state_dict().keys()
        tracker_v1.load_state_dict(tracker_dict_from_ckpt_v1, strict=True)

        fnet_dict_from_ckpt_v2 = {key.replace("track_predictor.coarse_fnet.", ""): val for key, val in ckpt_v2.items() if "track_predictor.coarse_fnet" in key}
        assert fnet_dict_from_ckpt_v2.keys() == fnet_v2.state_dict().keys()
        fnet_v2.load_state_dict(fnet_dict_from_ckpt_v2, strict=True)

        tracker_dict_from_ckpt_v2 = {key.replace("track_predictor.coarse_predictor.", ""): val for key, val in ckpt_v2.items() if "track_predictor.coarse_predictor" in key}
        assert tracker_dict_from_ckpt_v2.keys() == tracker_v2.state_dict().keys()
        tracker_v2.load_state_dict(tracker_dict_from_ckpt_v2, strict=True)

        superpoint = SuperPoint({"nms_radius": 4, "force_num_keypoints": True}).cuda().eval()
        sift = SIFT({}).cuda().eval()
        fnet_v1.eval()
        tracker_v1.eval()
        fnet_v2.eval()
        tracker_v2.eval()

        imgs = []
        imgs.append(cv.resize(cv.imread("/data/South-Building/images/P1180141.JPG"), (1920, 1080))[None, ...])
        imgs.append(cv.resize(cv.imread("/data/South-Building/images/P1180142.JPG"), (1920, 1080))[None, ...])
        imgs.append(cv.resize(cv.imread("/data/South-Building/images/P1180143.JPG"), (1920, 1080))[None, ...])
        imgs.append(cv.resize(cv.imread("/data/South-Building/images/P1180144.JPG"), (1920, 1080))[None, ...])
        imgs.append(cv.resize(cv.imread("/data/South-Building/images/P1180145.JPG"), (1920, 1080))[None, ...])
        imgs.append(cv.resize(cv.imread("/data/South-Building/images/P1180146.JPG"), (1920, 1080))[None, ...])
        imgs.append(cv.resize(cv.imread("/data/South-Building/images/P1180147.JPG"), (1920, 1080))[None, ...])
        imgs.append(cv.resize(cv.imread("/data/South-Building/images/P1180148.JPG"), (1920, 1080))[None, ...])
        imgs_numpy = np.concatenate(imgs)

        video = torch.from_numpy(imgs_numpy).unsqueeze(0).permute(0, 1, 4, 2, 3).float().cuda() / 255.

        query_points = get_query_points(superpoint, sift, video[0, 0:1, ...]).cuda()

        original_h, original_w = 1080, 1920
        new_h, new_w = 384, 512
        stride = 4.

        video = F.interpolate(video, size=(3, new_h, new_w))
        query_points[:, :, 0] *= new_w / original_w
        query_points[:, :, 1] *= new_h / original_h

        ## version 1 (vggsfm_v102.bin)
        with torch.no_grad():
            fmaps = fnet_v1(video.flatten(0, 1)).unsqueeze(0)

            coord_pred_list, vis_pred = tracker_v1(query_points / stride, fmaps)
            coord_pred = coord_pred_list[-1]

            # visualize
            imgs_to_show = copy.deepcopy(imgs)
            imgs_to_show = [cv.resize(img[0], (new_w, new_h)) for img in imgs_to_show]
            for i in range(coord_pred.shape[2]):
                if (vis_pred[0, :, i] > 0.05).sum().item() == len(imgs):
                    for j in range(len(imgs_to_show)):
                        if vis_pred[0, j, i].item() > 0.05:
                            cv.circle(imgs_to_show[j], (int(coord_pred[0, j, i, 0]), int(coord_pred[0, j, i, 1])), 3, (0, 0, 255), -1)

            cat_img = cv.hconcat(imgs_to_show)
            cat_img_left, cat_img_right = cat_img[:, :cat_img.shape[1] // 2], cat_img[:, cat_img.shape[1] // 2:]
            cat_img = cv.vconcat([cat_img_left, cat_img_right])
            cv.imwrite("track_full_video_version_1.png", cat_img)

            # Count in how many frames each predicted track is marked visible.
            track_length = [(vis_pred[0, :, i] > 0.05).sum().item() for i in range(vis_pred.shape[2])]
            track_length = np.array(track_length)
            ct = Counter(track_length)

            track_count_version_1 = [ct[i] for i in range(1, len(imgs) + 1)]

        ## version 2 (vggsfm_v2_0_0.bin)
        with torch.no_grad():
            fmaps = fnet_v2(video.flatten(0, 1)).unsqueeze(0)

            coord_pred_list, vis_pred = tracker_v2(query_points / stride, fmaps)
            coord_pred = coord_pred_list[-1]

            # visualize
            imgs_to_show = copy.deepcopy(imgs)
            imgs_to_show = [cv.resize(img[0], (new_w, new_h)) for img in imgs_to_show]
            for i in range(coord_pred.shape[2]):
                if (vis_pred[0, :, i] > 0.05).sum().item() == len(imgs):
                    for j in range(len(imgs_to_show)):
                        if vis_pred[0, j, i].item() > 0.05:
                            cv.circle(imgs_to_show[j], (int(coord_pred[0, j, i, 0]), int(coord_pred[0, j, i, 1])), 3, (0, 0, 255), -1)

            cat_img = cv.hconcat(imgs_to_show)
            cat_img_left, cat_img_right = cat_img[:, :cat_img.shape[1] // 2], cat_img[:, cat_img.shape[1] // 2:]
            cat_img = cv.vconcat([cat_img_left, cat_img_right])
            cv.imwrite("track_full_video_version_2.png", cat_img)

            track_length = [(vis_pred[0, :, i] > 0.05).sum().item() for i in range(vis_pred.shape[2])]
            track_length = np.array(track_length)
            ct = Counter(track_length)
            track_count_version_2 = [ct[i] for i in range(1, len(imgs) + 1)]

        # Plot how many tracks reach each track length for the two checkpoints.
        x = [i for i in range(1, len(imgs) + 1)]
        plt.plot(x, track_count_version_1, color='b', label='vgg_v1')
        plt.plot(x, track_count_version_2, color='r', label='vgg_v2')
        plt.xlabel("Track Length")
        plt.ylabel("Count")
        plt.legend()
        plt.savefig("vgg_v1&v2_comparison.png")


    if __name__ == "__main__":
        main()


for images P1180141 ~ P1180148, the track count curve:
![vgg_v1 v2_comparison](https://github.com/facebookresearch/vggsfm/assets/44374058/31ddf121-20bd-4ef5-b652-d32c8dc82203)
for P1180181 ~ P1180188:
![vgg_v1 v2_comparison](https://github.com/facebookresearch/vggsfm/assets/44374058/f95f75d3-19f0-4276-b986-8c3ab45b0739)
for P1180201 ~ P1180208:
![vgg_v1 v2_comparison](https://github.com/facebookresearch/vggsfm/assets/44374058/90e4b7fd-588a-4899-9fda-13423976d5dc)
For other image inputs, the difference between v2.0 and v1.0 is small, but in the cases above it seems that the v2.0 model does not completely outperform the v1.0 model.

So would you mind sharing some information about the "internal experiments" comparing these two trackers? It would definitely help a lot.

Looking forward to your reply. 
jytime commented 2 months ago

Hi,

  1. That is right.
  2. I did not compare these two training strategies, because CoTracker's "unrolled window training" assumes temporal continuity, which generally cannot be satisfied in SfM since SfM takes unordered images as input. So the two are not comparable in the SfM setting. However, feel free to replace our coarse tracker with CoTracker (or any other video tracker) if you are dealing with video inputs. Note that in this case, you still need to apply our fine tracker.
  3. I cannot share the details, but it concerns generalisation ability rather than performance on a single dataset.
qsisi commented 2 months ago

Thanks for your reply.

  1. Do you have any plans to release the training code of the tracker? That would resolve most of the questions here.
  2. Following the metric computation of CoTracker, I tested vgg_v1 against vgg_v2 on tap_vid_davis_first; here are the results: v1.0: [results image] v2.0: [results image]. It seems that v1.0 outperforms v2.0 in terms of 'occlusion_accuracy' and 'average_jaccard'. I agree that v2.0 may have stronger generalization ability than v1.0 in your "internal experiments"; I'm just curious about what differed in the training of these two trackers. As I said, looking forward to the release of the training scripts, which would help a lot with the above questions. (A simplified sketch of the occlusion-accuracy metric is given after this list.)
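For readers unfamiliar with the TAP-Vid metrics mentioned above, a simplified sketch of occlusion accuracy (visibility-classification accuracy); the official CoTracker evaluator additionally handles query frames and per-video averaging, which is omitted here:

    import numpy as np

    def occlusion_accuracy(pred_visible, gt_visible):
        # Fraction of (track, frame) entries where the predicted visibility flag
        # matches the ground truth. This is a simplified version of TAP-Vid's
        # 'occlusion_accuracy'; 'average_jaccard' additionally requires position
        # accuracy within pixel thresholds and is not sketched here.
        pred_visible = np.asarray(pred_visible, dtype=bool)
        gt_visible = np.asarray(gt_visible, dtype=bool)
        return float((pred_visible == gt_visible).mean())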
jytime commented 2 months ago

Hi @qsisi ,

  1. Yes, we do plan to release it, but it is not our top priority in the near future; for now we are focusing on providing more powerful inference scripts. If you really want to see how the training is conducted, you can refer to the dirty_train branch of this repo, https://github.com/facebookresearch/vggsfm/tree/dirty_train/dirty_train. The dirty_train branch provides some of the original training files, though they have not been cleaned.
  2. I would expect the performance to differ, but not by such a large margin. In my own testing, the tracker (both v1 and v2) should work at least as well as CoTracker. During testing, did you crop and resize the input video the way our dataloader does (https://github.com/facebookresearch/vggsfm/blob/main/vggsfm/datasets/demo_loader.py)? Basically, we pad the image to a square and resize it to 1024x1024, and the model was only trained in this manner. If you feed in images that do not follow this style, the performance will drop a lot (see the preprocessing sketch after this list).
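To illustrate the preprocessing described above (pad to a square, then resize to 1024x1024), a minimal sketch; it is not the repo's actual dataloader code, which instead builds a square crop box along the longest side (see demo_loader.py):

    import torch
    import torch.nn.functional as F

    def pad_to_square_and_resize(image: torch.Tensor, target_size: int = 1024) -> torch.Tensor:
        # image: (C, H, W) float tensor. Pad the shorter side with zeros so the
        # image becomes square, then resize to target_size x target_size.
        c, h, w = image.shape
        side = max(h, w)
        padded = image.new_zeros((c, side, side))
        padded[:, :h, :w] = image  # top-left placement; the real loader may center the content
        resized = F.interpolate(padded.unsqueeze(0), size=(target_size, target_size),
                                mode="bilinear", align_corners=False)
        return resized.squeeze(0)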
jytime commented 2 months ago

By the way, among the checkpoints v1.0, v1.1, and v1.2, which one works best for you? I can also put it in the README in case someone else needs it. I will also run another comparison between the v1.x and v2.0 checkpoints.

qsisi commented 2 months ago

@jytime

  1. Thanks for providing the dirty_train branch, which is very helpful! However, I noticed that the MegaDepth code is not provided, and that is the part I'm most interested in. Could you also upload that part? :)
  2. I tested directly on tap_vid_davis_first, which I believe resizes the images to 256x256. Performance may drop at this resolution, but it's still strange to see such a big gap between v1.2 and v2.0.
  3. I have only tested vggsfm_v102.bin and vggsfm_v2_0_0.bin. On my own data (autonomous driving data), over a couple of trials (testing only the coarse tracker), vggsfm_v102.bin performs better than vggsfm_v2_0_0.bin. Looking forward to seeing your comparison between v1.x and v2.0.
jytime commented 2 months ago

Hi @qsisi ,

I have updated the dataset file for MegaDepth.

qsisi commented 2 months ago

@jytime

Thanks for your help! Also, when will you upload the "scene_info.npz" files for MegaDepth, as well as the implementation of the "rawcamera_to_track" function? That would be very helpful for understanding the code.

jytime commented 2 months ago

Hi,

The "rawcamera_to_track" has been uploaded to https://github.com/facebookresearch/vggsfm/blob/dirty_train/dirty_train/dataset_util.py

Regarding scene_info.npz, the files were generated by previous works (such as this). If I remember correctly, they are downloaded automatically when you download MegaDepth via glue-factory, or you can generate them with your own processing as guided here: https://github.com/mihaidusmanu/d2-net

qsisi commented 2 months ago

Thanks for your update.

  1. It looks like the visibility flag is generated by comparing proj_depth with depth_by_img. However, in my trials on MegaDepth, I observed cases where a point is actually visible in a view but gets labeled "invisible" due to extrinsics error and incomplete depth maps. For example, point 0 in image A is projected into image B as point 1 (screenshots omitted); because the extrinsics are slightly off, point 1 lands on a hole in image B's depth map and is sampled with zero depth, so depth_by_img is 0 and the visibility flag is set to invisible, even though point 1 is actually visible in image B. Do these inaccurate ground-truth supervision signals harm the training? (See the depth-consistency sketch after this list.)
  2. Do we have to turn on the return_track flag for MegaDepth training? It uses top-k to sample the tracks with the most visibilities; in that case, question 1 is no longer an issue, because in my understanding selected_visibility is then mostly True, with almost no invisible points.
  3. A balanced_ce_loss-related question: https://github.com/facebookresearch/co-tracker/issues/94. Why use this balanced_ce_loss instead of nn.BCELoss()? It is strange to me that nn.BCELoss(gt, gt) = 0, but in your implementation balanced_ce_loss(gt, gt) != 0 :)
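A minimal sketch of the depth-consistency check described in point 1, assuming proj_depth (depth implied by projecting the point into the target view) and depth_by_img (depth sampled from the target view's depth map) are numpy arrays; the threshold and the exact formulation in the training code may differ:

    import numpy as np

    def visibility_from_depth(proj_depth, depth_by_img, rel_thresh=0.05):
        # A point is kept as visible only if the target view actually has depth at
        # the projected location and that depth agrees with the projected depth.
        # Points landing on depth holes (depth == 0) are marked invisible, which is
        # exactly where the "false invisible" labels discussed above come from.
        has_depth = depth_by_img > 0
        consistent = np.abs(proj_depth - depth_by_img) < rel_thresh * np.maximum(proj_depth, 1e-6)
        return has_depth & consistent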

Thanks again for your update and patience!

jytime commented 2 months ago
  1. Yes, this is expected. It comes from the inaccurate ground-truth dense depth maps of MegaDepth. We usually cannot correct them, so the simplest option is to drop such points, i.e., mark them as non-visible.
  2. You can set return_track=False, but that will lead to many non-visible sampled tracks. I think it will not change the performance much, but it may affect the convergence time.
  3. I did not think much about this, as it does not seem to be an important choice for the overall system.
qsisi commented 2 months ago

So you set return_track=True and select the tracks with the most visibilities using top-k, i.e., nearly no non-visible tracks were sampled during training?

jytime commented 2 months ago

We never sample tracks that are invisible in all frames, but tracks can be invisible in some frames, e.g., a track may be visible in 5 frames and invisible in 2.
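A minimal sketch of this kind of visibility-based track selection, under assumed tensor shapes (tracks: (N, S, 2) positions over S frames, visibility: (N, S) boolean); it is not the repo's actual sampling code:

    import torch

    def select_tracks_by_visibility(tracks, visibility, track_num=512):
        # Keep the track_num tracks that are visible in the most frames. Tracks
        # that are invisible in every frame are never selected, but partially
        # visible tracks can be (e.g. visible in 5 frames, invisible in 2).
        vis_count = visibility.sum(dim=1)              # visible-frame count per track
        keep = vis_count > 0                           # drop fully invisible tracks
        tracks, visibility, vis_count = tracks[keep], visibility[keep], vis_count[keep]
        k = min(track_num, tracks.shape[0])
        top_idx = torch.topk(vis_count, k).indices     # most widely visible tracks first
        return tracks[top_idx], visibility[top_idx]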

qsisi commented 2 months ago

Have you found that the above inaccurate "visibility" flags harm the training? Or should we only sample tracks that are visible in all frames, to avoid the inaccurate "visibility" flags? Thanks for your advice!

jytime commented 2 months ago

No. Such inaccurate "visibility" flags are not a problem for our training.

qsisi commented 2 months ago

Thanks for your reply!

I trained a vanilla tracker on MegaDepth, but its performance is far from VGGSfM's :(

So I was wondering whether you could provide the data configuration yaml for MegaDepth (whether to do cropping, RandomErasing, etc.) :) It looks like the vggsfm_v5.yaml in the 'dirty_train' branch is configured for Co3D.

Thanks for your help!

jytime commented 2 months ago
    node_num = 1
    gpu_num = 8

    accelerate_args = {
        "num_machines": node_num,
        "multi_gpu": True,
        "num_processes": gpu_num * node_num,  # 8 gpus requested
        "num_cpu_threads_per_process": 12,
    }

    hydra_config = "../cfgs/vggsfm_v5.yaml"
    base_conf = OmegaConf.load(hydra_config)

    # Common params

    base_conf.seed = 100866

    grid_param = {
        "load_camera": [False],
        "adapad": [False],
        "pre_factor": [2],
        "mixed_precision": ["bf16"],
        "train.img_size": [1024],
        "rot_aug": [True],
        "inside_shuffle": [True],
        "MODEL.ENCODER.stride": [4],
        "clip_trackL": [512],
        "train.mixset": ["m"],    # YOU CAN ALSO SET "km", which uses kubric and megadepth
        "train.erase_aug": [True],
        "repeat_mix": [1],
        "batch_size": [4],  # "batch_size": [4, 2,],
        "train.track_num": [512],
        "dynamix": [True],
        "train.max_images": [64],
        "train.lr": [0.0001],
        "train.len_train": [4096],
    }

Please see this config, which trains the tracker on MegaDepth with 8 GPUs (the v2 checkpoint was trained with 8 GPUs). It basically inherits the defaults from vggsfm_v5.yaml and uses grid_param to override some flags. We use the same config for training on Kubric, or you can combine Kubric and MegaDepth directly for training. If you still find training hard, try freezing the fine tracker at the beginning of training (a minimal sketch is shown below); with the coarse tracker alone you should already be able to achieve results that look good to the human eye.
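Freezing the fine tracker boils down to disabling gradients for the fine-tracker submodules so that only the coarse tracker is updated at first. A minimal sketch; the attribute name substrings (fine_fnet, fine_predictor) are assumptions based on the checkpoint key prefixes seen earlier in this thread, not a verified API:

    import torch.nn as nn

    def freeze_fine_tracker(model: nn.Module) -> None:
        # Disable gradient updates for the fine tracker; the coarse tracker and
        # its feature network keep training as usual.
        for name, param in model.named_parameters():
            if "fine_fnet" in name or "fine_predictor" in name:
                param.requires_grad = False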

wzds2015 commented 2 months ago
  1. Do you have a plan to release a model trained on arbitrary image shapes? The square-image assumption isn't useful for most cases.
  2. Is there any workaround using the current square-input model? I guess padding the image to a square won't work.
jytime commented 2 months ago

Hi @wzds2015 , I’m not sure if I understand your question correctly. Our inference process, as demonstrated in our demo.py file, supports images of any shape. During inference, images are padded to a square and resized to a fixed resolution.

You can check our Hugging Face demo at this link for a demonstration. It works quite well on images with different shapes.

wzds2015 commented 2 months ago

> Hi @wzds2015 , I’m not sure if I understand your question correctly. Our inference process, as demonstrated in our demo.py file, supports images of any shape. During inference, images are padded to a square and resized to a fixed resolution.
>
> You can check our Hugging Face demo at this link for a demonstration. It works quite well on images with different shapes.

Hi Jianyuan, thanks for the response. I am reading the source code. The current source code in the GitHub repo doesn't use the padding you mentioned, right? I saw that the dataloader forces crop_longest. If that is not the case, could you provide some code pointers in the repo? I want to make sure I use the inverted camera poses correctly in my downstream work; basically, I want to mimic the COLMAP reconstruction outputs and work on other tasks.

wzds2015 commented 2 months ago

Ah, I see what you meant. The crop function can actually also serve as a padding function.

jytime commented 2 months ago

> Ah, I see what you meant. The crop function can actually also serve as a padding function.

Yes. For example, if you have an image of size (1080, 720) and use crop_longest, the bbox will have a size of (1080, 1080), which effectively pads zeros along the shorter side.

https://github.com/facebookresearch/vggsfm/blob/9263351e7591a45139944f9a829ebd872bd6760e/vggsfm/datasets/demo_loader.py#L205C17-L205C21
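To make the (1080, 720) example concrete, a small sketch of how a square box sized to the longest side behaves like zero padding; the box convention below (centered, top-left offsets) is illustrative and may not match the loader exactly:

    def crop_longest_bbox(h: int, w: int):
        # Square box whose side equals the longest image dimension, centered on the
        # image; any box area outside the original image is filled with zeros.
        side = max(h, w)
        top = (h - side) // 2   # negative when h < side -> zero padding above/below
        left = (w - side) // 2  # negative when w < side -> zero padding left/right
        return top, left, side

    # A 1080 x 720 image gets a 1080 x 1080 box, i.e. 180 px of zero padding
    # on each side of the width.
    print(crop_longest_bbox(1080, 720))  # (0, -180, 1080)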

wzds2015 commented 2 months ago


Many thanks, Jianyuan. Is it okay to connect on WeChat? My email: wz927@nyu.edu

qsisi commented 2 months ago

Hi @jytime

https://github.com/facebookresearch/vggsfm/blob/dirty_train/dirty_train/megadepthV2.py#L85

Why filter out these scenes?

jytime commented 1 month ago

Hi @wzds2015 mine is JianyuanJay

jytime commented 1 month ago

Hi @qsisi, it is because some MegaDepth scenes (1) are used as a validation set and (2) may have bad quality. We just follow common practice, e.g., see here: https://github.com/cvg/glue-factory/tree/main/gluefactory/datasets/megadepth_scene_lists

qsisi commented 1 month ago

Sorry to bother you again. It looks like you are using a mixed dataset (Kubric and MegaDepth). Could you kindly release the implementation of the Kubric dataset?

Thanks for your help!

jytime commented 1 month ago

Hey I uploaded the code for imc, re10k, and kubric.

kavehsfv commented 1 month ago

> Following the [metric computation](https://github.com/facebookresearch/co-tracker/blob/9ed05317b794cd177674e681321780614a65e073/cotracker/evaluation/core/evaluator.py#L35) of CoTracker, I tested vgg_v1 against vgg_v2 on [tap_vid_davis_first](https://github.com/facebookresearch/co-tracker/blob/main/cotracker/datasets/tap_vid_datasets.py#L136); here are the results: v1.0: [results image] v2.0: [results image]. It seems that v1.0 outperforms v2.0 in terms of 'occlusion_accuracy' and 'average_jaccard'. [...]

Hi @qsisi, would you please share your code for evaluating both trackers? I need to compare this tracker with CoTracker in my research, and your code would be a huge help. Thank you in advance!

qsisi commented 1 month ago

> Hi @qsisi, would you please share your code for evaluating both trackers? I need to compare this tracker with CoTracker in my research, and your code would be a huge help. Thank you in advance!

Here's how I tested them: https://github.com/facebookresearch/vggsfm/issues/21#issuecomment-2214074932 (only the coarse part of the tracker is included).