jchengai / forecast-mae

[ICCV'2023] Forecast-MAE: Self-supervised Pre-training for Motion Forecasting with Masked Autoencoders
https://arxiv.org/pdf/2308.09882.pdf

What are the computational time and the test GPU device? #6

Closed: teddyluo closed this issue 9 months ago

teddyluo commented 9 months ago

Hi, I am very interested in this method. I have two questions:

1) SSL-Lanes is known to be quite fast and computationally efficient. Could you provide the computational cost of Forecast-MAE and the test GPU device for reference?

2) You said that MPNP would be open-sourced, but it has not been released yet. Could you share more material on MPNP (GitHub page: https://jchengai.github.io/mpnp/)?

Many thanks.

jchengai commented 9 months ago

Hi, thanks for your interest in our work.

  1. Some preliminary results: the uncompiled model runs at about 115 FPS (batch_size=1, preprocessed data) on an RTX 4080.

  2. Since the Lyft5 dataset is no longer maintained, I'm considering a re-implementation based on nuPlan.
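
For reference, a throughput number like the one in item 1 could be measured with a small timing loop. The sketch below is only an assumption about a reasonable setup, not the script that produced the 115 FPS figure; `measure_fps`, `step`, and `sync` are hypothetical names introduced for illustration.

```python
import time


def measure_fps(step, warmup=10, iters=100, sync=None):
    """Return how many times per second the zero-argument callable `step` runs.

    `sync` is an optional zero-argument callable that blocks until pending
    asynchronous work has finished (needed for GPU timing).
    """
    # Warm-up calls are excluded from timing (caches, lazy initialization).
    for _ in range(warmup):
        step()
    if sync is not None:
        sync()

    start = time.perf_counter()
    for _ in range(iters):
        step()
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start

    # Guard against a zero reading on very coarse timers.
    return iters / elapsed if elapsed > 0 else float("inf")
```

With a PyTorch model on a GPU, one would pass something like `lambda: model(batch)` as `step` (with `model` and `batch` as placeholders for an already loaded model and a preprocessed batch of size 1) and `torch.cuda.synchronize` as `sync`, ideally inside `torch.no_grad()`. Without the synchronization, the timer may stop before the queued GPU kernels have finished and overstate the FPS.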

teddyluo commented 9 months ago

Many thanks for your reply. I have been waiting for it for 2 years. Looking forward to it.

teddyluo commented 9 months ago

By the way, I noticed that the figures in the paper are very nice. Would you mind sharing the code for visualizing the results one by one with the SSL-Lanes model and the Forecast-MAE model, respectively? Thanks.

jchengai commented 9 months ago

The code is a little bit messy, but here it is for your reference:

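from pathlib import Path

import matplotlib.pyplot as plt
import torch
from av2.datasets.motion_forecasting import scenario_serialization
from av2.map.map_api import ArgoverseStaticMap
from tqdm import tqdm

# Not shown in this snippet and assumed to be defined elsewhere:
# dataset, ssl_dataset, data_root, collate_fn, ssl_collate_fn,
# model_finetune, model_scratch, model_ssl, visualize_scenario.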

save_path = "./qualitative_for_suppl"
Path(save_path).mkdir(parents=True, exist_ok=True)
row = 4  # scenes per figure; offsets the panel numbering across figures

visual_group = [[881, 7407, 4098, 4971], [5061, 6601, 7508, 94]]

for step in tqdm(range(len(visual_group))):
    save_file = Path(save_path) / f"{step}.pdf"
    visualize_scene = visual_group[step]
    n = len(visualize_scene)
    col = 3

    fig, axes = plt.subplots(
        n,
        col,
        figsize=(5 * col, 5 * n),
        gridspec_kw={"wspace": 0, "hspace": 0},
    )
    plt.rcParams["font.family"] = ["Serif"]
    plt.rcParams["font.serif"] = ["Times New Roman"]
    plt.rcParams["mathtext.default"] = "regular"
    plt.rcParams["xtick.direction"] = "in"
    plt.rcParams["ytick.direction"] = "in"
    plt.rcParams["xtick.labelsize"] = 14
    plt.rcParams["ytick.labelsize"] = 12
    plt.rcParams["pdf.fonttype"] = 42
    plt.rcParams["ps.fonttype"] = 42
    # Requires a working LaTeX installation; the \rom macro below depends on it.
    plt.rcParams["text.usetex"] = True
    plt.rcParams[
        "text.latex.preamble"
    ] = r"\makeatletter \newcommand*{\rom}[1]{\expandafter\@slowromancap\romannumeral #1@} \makeatother"

    for i in range(n):
        for j in range(col):
            axes[i][j].get_xaxis().set_visible(False)
            axes[i][j].get_yaxis().set_visible(False)

    with torch.no_grad():
        for i, scene in enumerate(visualize_scene):
            data = dataset[scene]
            scene_id = data["scenario_id"]

            ssl_data = None
            for ssl_data_tmp in ssl_dataset:
                if ssl_data_tmp["scene_id"] == scene_id:
                    ssl_data = ssl_data_tmp
                    break

            scene_file = data_root / scene_id / ("scenario_" + scene_id + ".parquet")
            map_file = data_root / scene_id / ("log_map_archive_" + scene_id + ".json")
            scenario = scenario_serialization.load_argoverse_scenario_parquet(
                scene_file
            )
            static_map = ArgoverseStaticMap.from_json(map_file)

            f_traj, f_prob = model_finetune.predict(collate_fn([data]))
            s_traj, s_prob = model_scratch.predict(collate_fn([data]))
            output, _ = model_ssl(ssl_collate_fn([ssl_data]))
            f_traj, f_prob = f_traj.squeeze(0), f_prob.squeeze(0)
            s_traj, s_prob = s_traj.squeeze(0), s_prob.squeeze(0)
            ssl_traj = output["reg"][0][0].cpu().numpy()

            visualize_scenario(
                axes[i][0],
                scenario,
                static_map,
                mask=False,
                preds=ssl_traj,
                prob=None,
            )
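            # Scratch model: pair its trajectories with its own probabilities.
            visualize_scenario(
                axes[i][1],
                scenario,
                static_map,
                mask=False,
                preds=s_traj,
                prob=s_prob,
            )
            # Fine-tuned model: likewise use its own probabilities.
            visualize_scenario(
                axes[i][2],
                scenario,
                static_map,
                mask=False,
                preds=f_traj,
                prob=f_prob,
            )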

            x_lim = axes[i][0].get_xlim()
            y_lim = axes[i][0].get_ylim()

            if i == 0 and step == 0:
                texts = ["SSL-Lanes", "Scratch", "Fine-tune (Ours)"]
                for j in range(3):
                    axes[i][j].text(
                        0.5,
                        1.05,
                        texts[j],
                        horizontalalignment="center",
                        verticalalignment="center",
                        fontsize=30,
                        transform=axes[i][j].transAxes,
                    )

            # Roman-numeral panel label, numbered continuously across figures.
            axes[i][0].text(
                0.92,
                0.93,
                r"(\rom{{{}}})".format(i + 1 + step * row),
                horizontalalignment="center",
                verticalalignment="center",
                fontsize=20,
                transform=axes[i][0].transAxes,
                zorder=5000,
                color="#002fa7",
            )

    plt.tight_layout()
    plt.savefig(
        save_file,
        dpi=300,
        transparent=False,
        bbox_inches="tight",
        pad_inches=0,
    )
    plt.close(fig)
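
One possible cleanup of the snippet above: the loop over `ssl_dataset` rescans the whole dataset for every scene and silently leaves `ssl_data` as `None` when nothing matches. A minimal sketch of a one-time lookup table follows, assuming each sample is a dict with a `"scene_id"` key as in the loop above; `build_scene_index` and the toy data are hypothetical and only illustrate the idea.

```python
def build_scene_index(samples, key="scene_id"):
    """Map each sample's scene id to its position in `samples`."""
    index = {}
    for pos, sample in enumerate(samples):
        # Keep the first occurrence, matching the `break` in the linear scan.
        index.setdefault(sample[key], pos)
    return index


# Toy stand-in for ssl_dataset; real samples would hold model inputs.
toy_ssl_dataset = [
    {"scene_id": "a1", "payload": 0},
    {"scene_id": "b2", "payload": 1},
    {"scene_id": "c3", "payload": 2},
]

scene_index = build_scene_index(toy_ssl_dataset)

# Look up a scene; an explicit None check makes a missing scene visible.
pos = scene_index.get("b2")
ssl_data = toy_ssl_dataset[pos] if pos is not None else None
```

Building the index still walks the dataset once, so it only pays off when several scenes are visualized, as is the case with the eight scenes in `visual_group`.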