MARS: An Instance-aware, Modular and Realistic Simulator for Autonomous Driving
Apache License 2.0

High memory usage while loading datasets #85

Closed JPenshornTAP closed 7 months ago

JPenshornTAP commented 9 months ago

Hi, while loading the vkitti2 dataset I'm experiencing huge RAM loads of up to 200 GB on both the master and refactor branches. Since the dataset itself is only around 16 GB, I was wondering whether this is intended or necessary. Do you have any idea why this occurs?

wuzirui commented 9 months ago

It may be caused by the ray-object intersection computation; we will optimize that in the future. As a quick fix, you can delete the unused variables in the dataparsers. They contain some redundant code from NSG that is not necessary.
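To see why the ray-object intersection stage can dwarf the on-disk dataset size, here is a rough back-of-the-envelope sketch. It assumes the dense per-ray layout `[N_rays, 3 + 2 * n_obj, 3]` implied by `_n_obj = (rays_rgb.shape[1] - 3) // 2` in the dataparser posted later in this thread, stored as float32. The function name `estimate_ray_buffer_bytes` and the example numbers are illustrative assumptions, not measurements from MARS.

```python
def estimate_ray_buffer_bytes(n_frames, n_cams, height, width, n_obj, bytes_per_value=4):
    """Estimate the size of a dense per-ray buffer of shape [N_rays, 3 + 2 * n_obj, 3]."""
    n_rays = n_frames * n_cams * height * width  # one ray per pixel
    # Assumed layout: 3 base rows (e.g. origin, direction, color) plus 2 rows per object.
    rows_per_ray = 3 + 2 * n_obj
    return n_rays * rows_per_ray * 3 * bytes_per_value


if __name__ == "__main__":
    # Hypothetical settings: 56 frames, 2 cameras, 375 x 1242 images, 10 objects per ray.
    size = estimate_ray_buffer_bytes(n_frames=56, n_cams=2, height=375, width=1242, n_obj=10)
    print(f"{size / 2**30:.1f} GiB")
```

Under these assumptions the buffer grows linearly with the number of pixels and with the number of tracked objects, so cropping the frame range or reducing the object count shrinks it proportionally.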

Nplace-su commented 8 months ago

Met the same issue here. Did you find any workaround? @JPenshornTAP Thank you.

PolarisKyle commented 8 months ago

Met the same issue here. How can the RAM cost be reduced, e.g. with something like PyTorch's DataLoader (torch.utils.data.DataLoader)?
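One common way to cut RAM in this spirit is to keep only file paths in memory and decode frames on demand, which is what a map-style PyTorch dataset does in `__getitem__`. The sketch below is library-free and purely illustrative: `LazyFrameStore`, its `loader` argument, and `max_cached` are hypothetical names, not part of MARS or nerfstudio.

```python
from collections import OrderedDict
from pathlib import Path


class LazyFrameStore:
    """Holds file paths only and decodes a frame when it is first requested."""

    def __init__(self, paths, loader, max_cached=8):
        self.paths = [Path(p) for p in paths]
        self.loader = loader  # callable that turns a Path into a decoded frame
        self.max_cached = max_cached
        self._cache = OrderedDict()  # small LRU cache of recently used frames

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        if index in self._cache:
            self._cache.move_to_end(index)  # mark as most recently used
            return self._cache[index]
        frame = self.loader(self.paths[index])
        self._cache[index] = frame
        if len(self._cache) > self.max_cached:
            self._cache.popitem(last=False)  # evict the least recently used frame
        return frame
```

An object like this exposes `__len__` and `__getitem__`, the two methods a map-style dataset needs, so the same pattern carries over to a `torch.utils.data.Dataset` subclass whose loader decodes an image file.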

wuzirui commented 8 months ago

Here's a refactored dataparser for the KITTI dataset. Some minor changes are needed to plug it into the current MARS release. Hope this helps you guys with the memory issue. This script will be included in a future release, stay tuned~

from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Type

import imageio
import numpy as np
import pandas as pd
import torch
from rich.console import Console

from mars.configs.storage_configs.base_storage_config import StorageConfig
from mars.instances.instances import Instances
from mars.utils.neural_scene_graph_helper import box_pts
from nerfstudio.cameras import camera_utils
from nerfstudio.cameras.cameras import Cameras, CameraType
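from nerfstudio.data.dataparsers.base_dataparser import (
    DataParser,
    DataParserConfig,
    DataparserOutputs,
)
from nerfstudio.data.scene_box import SceneBox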
from nerfstudio.plugins.registry_dataparser import DataParserSpecification
from nerfstudio.utils.colors import get_color

CONSOLE = Console(width=120)
_sem2label = {"Misc": -1, "Car": 0, "Van": 0, "Truck": 2, "Tram": 3, "Pedestrian": 4}
camera_ls = [2, 3]

def kitti_string_to_float(str):
    return float(str.split("e")[0]) * 10 ** int(str.split("e")[1])

def get_rotation(roll, pitch, heading):
    s_heading = np.sin(heading)
    c_heading = np.cos(heading)
    rot_z = np.array([[c_heading, -s_heading, 0], [s_heading, c_heading, 0], [0, 0, 1]])

    s_pitch = np.sin(pitch)
    c_pitch = np.cos(pitch)
    rot_y = np.array([[c_pitch, 0, s_pitch], [0, 1, 0], [-s_pitch, 0, c_pitch]])

    s_roll = np.sin(roll)
    c_roll = np.cos(roll)
    rot_x = np.array([[1, 0, 0], [0, c_roll, -s_roll], [0, s_roll, c_roll]])

    rot = np.matmul(rot_z, np.matmul(rot_y, rot_x))

    return rot

def invert_transformation(rot, t):
    t = np.matmul(-rot.T, t)
    inv_translation = np.concatenate([rot.T, t[:, None]], axis=1)
    return np.concatenate([inv_translation, np.array([[0.0, 0.0, 0.0, 1.0]])])

def calib_from_txt(calibration_path):
    """
    Read the calibration files and extract the required transformation matrices and focal length.

    Args:
        calibration_path (str): The path to the directory containing the calibration files.

    Returns:
        tuple: A tuple containing the following elements:
            imu2v (np.array): 4x4 transformation matrix from IMU to Velodyne coordinates.
            v2c (np.array): 4x4 transformation matrix from Velodyne to left camera coordinates.
            c2leftRGB (np.array): 4x4 transformation matrix from left camera to rectified left camera coordinates.
            c2rightRGB (np.array): 4x4 transformation matrix from right camera to rectified right camera coordinates.
            focal (float): Focal length of the left camera.
    """
    c2c = []

    # Read and parse the camera-to-camera calibration file
    with open(os.path.join(calibration_path, "calib_cam_to_cam.txt"), "r") as f:
        cam_to_cam_str = f.read()
    [left_cam, right_cam] = cam_to_cam_str.split("S_02: ")[1].split("S_03: ")
    cam_to_cam_ls = [left_cam, right_cam]

    # Extract the transformation matrices for left and right cameras
    for i, cam_str in enumerate(cam_to_cam_ls):
        r_str, t_str = cam_str.split("R_0" + str(i + 2) + ": ")[1].split("\nT_0" + str(i + 2) + ": ")
        t_str = t_str.split("\n")[0]
        R = np.array([kitti_string_to_float(r) for r in r_str.split(" ")])
        R = np.reshape(R, [3, 3])
        t = np.array([kitti_string_to_float(t) for t in t_str.split(" ")])
        Tr = np.concatenate([np.concatenate([R, t[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]])

        t_str_rect, s_rect_part = cam_str.split("\nT_0" + str(i + 2) + ": ")[1].split("\nS_rect_0" + str(i + 2) + ": ")
        s_rect_str, r_rect_part = s_rect_part.split("\nR_rect_0" + str(i + 2) + ": ")
        r_rect_str = r_rect_part.split("\nP_rect_0" + str(i + 2) + ": ")[0]
        R_rect = np.array([kitti_string_to_float(r) for r in r_rect_str.split(" ")])
        R_rect = np.reshape(R_rect, [3, 3])
        t_rect = np.array([kitti_string_to_float(t) for t in t_str_rect.split(" ")])
        Tr_rect = np.concatenate(
            [np.concatenate([R_rect, t_rect[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]]
        )

        c2c.append(Tr_rect)

    c2leftRGB = c2c[0]
    c2rightRGB = c2c[1]

    # Read and parse the Velodyne-to-camera calibration file
    with open(os.path.join(calibration_path, "calib_velo_to_cam.txt"), "r") as f:
        velo_to_cam_str = f.read()
    r_str, t_str = velo_to_cam_str.split("R: ")[1].split("\nT: ")
    t_str = t_str.split("\n")[0]
    R = np.array([kitti_string_to_float(r) for r in r_str.split(" ")])
    R = np.reshape(R, [3, 3])
    t = np.array([kitti_string_to_float(r) for r in t_str.split(" ")])
    v2c = np.concatenate([np.concatenate([R, t[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]])

    # Read and parse the IMU-to-Velodyne calibration file
    with open(os.path.join(calibration_path, "calib_imu_to_velo.txt"), "r") as f:
        imu_to_velo_str = f.read()
    r_str, t_str = imu_to_velo_str.split("R: ")[1].split("\nT: ")
    R = np.array([kitti_string_to_float(r) for r in r_str.split(" ")])
    R = np.reshape(R, [3, 3])
    t = np.array([kitti_string_to_float(r) for r in t_str.split(" ")])
    imu2v = np.concatenate([np.concatenate([R, t[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]])

    # Extract the focal length of the left camera
    focal = kitti_string_to_float(left_cam.split("P_rect_02: ")[1].split()[0])

    return imu2v, v2c, c2leftRGB, c2rightRGB, focal

def tracking_calib_from_txt(calibration_path):
    """
    Extract tracking calibration information from a KITTI tracking calibration file.

    This function reads a KITTI tracking calibration file and extracts the relevant
    calibration information, including projection matrices and transformation matrices
    for camera, LiDAR, and IMU coordinate systems.

    Args:
        calibration_path (str): Path to the KITTI tracking calibration file.

    Returns:
        dict: A dictionary containing the following calibration information:
            P0, P1, P2, P3 (np.array): 3x4 projection matrices for the cameras.
            Tr_cam2camrect (np.array): 4x4 transformation matrix from camera to rectified camera coordinates.
            Tr_velo2cam (np.array): 4x4 transformation matrix from LiDAR to camera coordinates.
            Tr_imu2velo (np.array): 4x4 transformation matrix from IMU to LiDAR coordinates.
    """
    # Read the calibration file
    with open(calibration_path) as f:
        calib_str = f.read().splitlines()

    # Process the calibration data
    calibs = []
    for calibration in calib_str:
        calibs.append(np.array([kitti_string_to_float(val) for val in calibration.split()[1:]]))

    # Extract the projection matrices
    P0 = np.reshape(calibs[0], [3, 4])
    P1 = np.reshape(calibs[1], [3, 4])
    P2 = np.reshape(calibs[2], [3, 4])
    P3 = np.reshape(calibs[3], [3, 4])

    # Extract the transformation matrix for camera to rectified camera coordinates
    Tr_cam2camrect = np.eye(4)
    R_rect = np.reshape(calibs[4], [3, 3])
    Tr_cam2camrect[:3, :3] = R_rect

    # Extract the transformation matrices for LiDAR to camera and IMU to LiDAR coordinates
    Tr_velo2cam = np.concatenate([np.reshape(calibs[5], [3, 4]), np.array([[0.0, 0.0, 0.0, 1.0]])], axis=0)
    Tr_imu2velo = np.concatenate([np.reshape(calibs[6], [3, 4]), np.array([[0.0, 0.0, 0.0, 1.0]])], axis=0)

    return {
        "P0": P0,
        "P1": P1,
        "P2": P2,
        "P3": P3,
        "Tr_cam2camrect": Tr_cam2camrect,
        "Tr_velo2cam": Tr_velo2cam,
        "Tr_imu2velo": Tr_imu2velo,
    }

def get_poses_calibration(basedir, oxts_path_tracking=None, selected_frames=None):
    """
    Extract poses and calibration information from the KITTI dataset.

    This function processes the OXTS data (GPS/IMU) and extracts the
    pose information (translation and rotation) for each frame. It also
    retrieves the calibration information (transformation matrices and focal length)
    required for further processing.

    Args:
        basedir (str): The base directory containing the KITTI dataset.
        oxts_path_tracking (str, optional): Path to the OXTS data file for tracking sequences.
            If not provided, the function will look for OXTS data in the basedir.
        selected_frames (list, optional): A list of frame indices to process.
            If not provided, all frames in the dataset will be processed.

    Returns:
        tuple: A tuple containing the following elements:
            poses (np.array): An array of 4x4 pose matrices representing the vehicle's
                position and orientation for each frame (IMU pose).
            calibrations (tuple or None): The transformation matrices and focal length
                returned by calib_from_txt, or None when oxts_path_tracking is given.
            focal (float or None): The focal length of the left camera, or None when
                oxts_path_tracking is given.
    """

    def oxts_to_pose(oxts):
        """
        OXTS (Oxford Technical Solutions) data typically refers to the data generated by an
        Inertial and GPS Navigation System (INS/GPS) that is used to provide accurate position,
        orientation, and velocity information for a moving platform, such as a vehicle. In the
        context of the KITTI dataset, OXTS data is used to provide the ground truth for the
        vehicle's trajectory and 6 degrees of freedom (6-DoF) motion, which is essential for
        evaluating and benchmarking various computer vision and robotics algorithms, such as
        visual odometry, SLAM, and object detection.

        The OXTS data contains several important measurements:

        1. Latitude, longitude, and altitude: These are the global coordinates of the moving platform.
        2. Roll, pitch, and yaw (heading): These are the orientation angles of the platform, usually given in Euler angles.
        3. Velocity (north, east, and down): These are the linear velocities of the platform in the local navigation frame.
        4. Accelerations (ax, ay, az): These are the linear accelerations in the platform's body frame.
        5. Angular rates (wx, wy, wz): These are the angular rates (also known as angular velocities) of the platform in its body frame.

        In the KITTI dataset, the OXTS data is stored as plain text files with each line
        corresponding to a timestamp. Each line in the file contains the aforementioned
        measurements, which are used to compute the ground truth trajectory and 6-DoF motion of
        the vehicle. This information can be further used for calibration, data synchronization,
        and performance evaluation of various algorithms.
        """
        poses = []

        def latlon_to_mercator(lat, lon, s):
            """
            Converts latitude and longitude coordinates to Mercator coordinates (x, y) using the given scale factor.

            The Mercator projection is a widely used cylindrical map projection that represents the Earth's surface
            as a flat, rectangular grid, distorting the size of geographical features in higher latitudes.
            This function uses the scale factor 's' to control the amount of distortion in the projection.

            Args:
                lat (float): Latitude in degrees, range: -90 to 90.
                lon (float): Longitude in degrees, range: -180 to 180.
                s (float): Scale factor, typically the cosine of the reference latitude.

            Returns:
                list: A list containing the Mercator coordinates [x, y] in meters.
            """
            r = 6378137.0  # the Earth's equatorial radius in meters
            x = s * r * ((np.pi * lon) / 180)
            y = s * r * np.log(np.tan((np.pi * (90 + lat)) / 360))
            return [x, y]

        # Compute the initial scale and pose based on the selected frames
        if selected_frames is None:
            lat0 = oxts[0][0]
            scale = np.cos(lat0 * np.pi / 180)
            pose_0_inv = None
        else:
            oxts0 = oxts[selected_frames[0][0]]
            lat0 = oxts0[0]
            scale = np.cos(lat0 * np.pi / 180)

            pose_i = np.eye(4)

            [x, y] = latlon_to_mercator(oxts0[0], oxts0[1], scale)
            z = oxts0[2]
            translation = np.array([x, y, z])
            rotation = get_rotation(oxts0[3], oxts0[4], oxts0[5])
            pose_i[:3, :] = np.concatenate([rotation, translation[:, None]], axis=1)
            pose_0_inv = invert_transformation(pose_i[:3, :3], pose_i[:3, 3])

        # Iterate through the OXTS data and compute the corresponding pose matrices
        for oxts_val in oxts:
            pose_i = np.zeros([4, 4])
            pose_i[3, 3] = 1

            [x, y] = latlon_to_mercator(oxts_val[0], oxts_val[1], scale)
            z = oxts_val[2]
            translation = np.array([x, y, z])

            roll = oxts_val[3]
            pitch = oxts_val[4]
            heading = oxts_val[5]
            rotation = get_rotation(roll, pitch, heading)  # (3,3)

            pose_i[:3, :] = np.concatenate([rotation, translation[:, None]], axis=1)  # (4, 4)
            if pose_0_inv is None:
                pose_0_inv = invert_transformation(pose_i[:3, :3], pose_i[:3, 3])

            pose_i = np.matmul(pose_0_inv, pose_i)
            poses.append(pose_i)

        return np.array(poses)

    # If there is no tracking path specified, use the default path
    if oxts_path_tracking is None:
        oxts_path = os.path.join(basedir, "oxts/data")
        oxts = np.array([np.loadtxt(os.path.join(oxts_path, file)) for file in sorted(os.listdir(oxts_path))])
        calibration_path = os.path.dirname(basedir)

        calibrations = calib_from_txt(calibration_path)

        focal = calibrations[4]

        poses = oxts_to_pose(oxts)

    # If a tracking path is specified, use it to load OXTS data and compute the poses
    else:
        oxts_tracking = np.loadtxt(oxts_path_tracking)
        poses = oxts_to_pose(oxts_tracking)  # (n_frames, 4, 4)
        calibrations = None
        focal = None
        # Set velodyne close to z = 0
        # poses[:, 2, 3] -= 0.8

    # Return the poses, calibrations, and focal length
    return poses, calibrations, focal

def get_camera_poses_tracking(poses_velo_w_tracking, tracking_calibration, selected_frames, scene_no=None):
    exp = False
    camera_poses = []

    opengl2kitti = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])

    start_frame = selected_frames[0]
    end_frame = selected_frames[1]

    # Debug Camera offset
    if scene_no == 2:
        yaw = np.deg2rad(0.7)  ## Affects camera rig roll: High --> counterclockwise
        pitch = np.deg2rad(-0.5)  ## Affects camera rig yaw: High --> Turn Right
        # pitch = np.deg2rad(-0.97)
        roll = np.deg2rad(0.9)  ## Affects camera rig pitch: High -->  up
        # roll = np.deg2rad(1.2)
    elif scene_no == 1:
        if exp:
            yaw = np.deg2rad(0.3)  ## Affects camera rig roll: High --> counterclockwise
            pitch = np.deg2rad(-0.6)  ## Affects camera rig yaw: High --> Turn Right
            # pitch = np.deg2rad(-0.97)
            roll = np.deg2rad(0.75)  ## Affects camera rig pitch: High -->  up
            # roll = np.deg2rad(1.2)
        else:
            yaw = np.deg2rad(0.5)  ## Affects camera rig roll: High --> counterclockwise
            pitch = np.deg2rad(-0.5)  ## Affects camera rig yaw: High --> Turn Right
            roll = np.deg2rad(0.75)  ## Affects camera rig pitch: High -->  up
    else:
        yaw = np.deg2rad(0.05)
        pitch = np.deg2rad(-0.75)
        # pitch = np.deg2rad(-0.97)
        roll = np.deg2rad(1.05)
        # roll = np.deg2rad(1.2)

    cam_debug = np.eye(4)
    cam_debug[:3, :3] = get_rotation(roll, pitch, yaw)

    Tr_cam2camrect = tracking_calibration["Tr_cam2camrect"]
    Tr_cam2camrect = np.matmul(Tr_cam2camrect, cam_debug)
    Tr_camrect2cam = invert_transformation(Tr_cam2camrect[:3, :3], Tr_cam2camrect[:3, 3])
    Tr_velo2cam = tracking_calibration["Tr_velo2cam"]
    Tr_cam2velo = invert_transformation(Tr_velo2cam[:3, :3], Tr_velo2cam[:3, 3])

    camera_poses_imu = []
    for cam in camera_ls:
        Tr_camrect2cam_i = tracking_calibration["Tr_camrect2cam0" + str(cam)]
        Tr_cam_i2camrect = invert_transformation(Tr_camrect2cam_i[:3, :3], Tr_camrect2cam_i[:3, 3])
        # transform camera axis from kitti to opengl for nerf:
        cam_i_camrect = np.matmul(Tr_cam_i2camrect, opengl2kitti)
        cam_i_cam0 = np.matmul(Tr_camrect2cam, cam_i_camrect)
        cam_i_velo = np.matmul(Tr_cam2velo, cam_i_cam0)

        cam_i_w = np.matmul(poses_velo_w_tracking, cam_i_velo)
        camera_poses_imu.append(cam_i_w)

    for i, cam in enumerate(camera_ls):
        for frame_no in range(start_frame, end_frame + 1):
            camera_poses.append(camera_poses_imu[i][frame_no])

    return np.array(camera_poses)

def get_obj_pose_tracking(tracklet_path, poses_imu_tracking, calibrations, selected_frames, transform_matrix):
    """
    Extracts object pose information from the KITTI motion tracking dataset for the specified frames.

    Parameters
    ----------
    tracklet_path : str
        Path to the text file containing tracklet information. A tracklet is a small sequence of
        object positions and orientations over time, often used in the context of object tracking
        and motion estimation in computer vision. In a dataset, a tracklet usually represents a
        single object's pose information across multiple consecutive frames. This information
        includes the object's position, orientation (usually as rotation around the vertical axis,
        i.e., yaw angle), and other attributes like object type, dimensions, etc. In the KITTI
        dataset, tracklets are used to store and provide ground truth information about dynamic
        objects in the scene, such as cars, pedestrians, and cyclists.

    poses_imu_tracking : list of numpy arrays
        A list of 4x4 transformation matrices representing the poses (positions and orientations)
        of the ego vehicle (the main vehicle equipped with sensors) in Inertial Measurement Unit
        (IMU) coordinates at different time instances (frames). Each matrix in the list
        corresponds to a single frame in the dataset.

    calibrations : dict
        Dictionary containing calibration information:
            - "Tr_velo2cam": 4x4 transformation matrix from Velodyne coordinates to camera coordinates.
            - "Tr_imu2velo": 4x4 transformation matrix from IMU coordinates to Velodyne coordinates.

    selected_frames : list of int
        List of two integers specifying the start and end frames to process.

    transform_matrix : numpy array
        4x4 transformation matrix applied to each object's world pose.

    Returns
    -------
    Instances
        Object Instances instance holding the per-object centers, yaws, timestamps, and valid mask.
    """

    # Helper function to generate a rotation matrix around the y-axis
    def roty_matrix(roty):
        c = np.cos(roty)
        s = np.sin(roty)
        return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])

    # Extract calibration data
    velo2cam = calibrations["Tr_velo2cam"]
    imu2velo = calibrations["Tr_imu2velo"]
    cam2velo = invert_transformation(velo2cam[:3, :3], velo2cam[:3, 3])
    velo2imu = invert_transformation(imu2velo[:3, :3], imu2velo[:3, 3])

    start_frame = selected_frames[0]
    end_frame = selected_frames[1]

    # Read tracklets from file
    with open(tracklet_path) as f:
        tracklets_str = f.read().splitlines()

    track_ids, category_ids, lengths, heights, widths = [], [], [], [], []
    track2index = {}
    centers_list = []
    yaws_list = []
    timestamps_list = []

    # Extract metadata for all objects in the scene
    for tracklet in tracklets_str:
        tracklet = tracklet.split()
        if float(tracklet[1]) < 0:
            continue
        frame_id = int(tracklet[0])
        if frame_id < start_frame or frame_id > end_frame:
            continue

        track_id = int(tracklet[1])

        if tracklet[2] not in _sem2label:
            continue
        type = _sem2label[tracklet[2]]

        if not track_id in track2index:  # extract metadata
            height = float(tracklet[10])
            width = float(tracklet[11])
            length = float(tracklet[12])
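            track2index[track_id] = len(track_ids)
            # Record per-object metadata so that len(track_ids) advances with each new track
            track_ids.append(track_id)
            category_ids.append(type)
            lengths.append(length)
            heights.append(height)
            widths.append(width)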

            # create empty tensors for centers and yaws
            centers_list.append(torch.zeros((end_frame - start_frame + 1, 3), dtype=torch.float64))
            yaws_list.append(torch.zeros((end_frame - start_frame + 1, 1), dtype=torch.float64))
            timestamps_list.append(-1.0 * torch.ones((end_frame - start_frame + 1, 1), dtype=torch.float64))

        # extract object track
        index = track2index[track_id]
        timestamps_list[index][frame_id - start_frame] = float(frame_id - start_frame)

        # Initialize a 4x4 identity matrix for object pose in camera coordinates
        pose = np.array(
            [float(tracklet[13]), float(tracklet[14]), float(tracklet[15]), float(tracklet[16])]
        )  # x,y,z,yaw
        obj_pose_c = np.eye(4)
        obj_pose_c[:3, 3] = pose[:3]
        roty = pose[3]
        obj_pose_c[:3, :3] = roty_matrix(roty)

        # Transform object pose from camera coordinates to IMU coordinates
        obj_pose_imu = np.matmul(velo2imu, np.matmul(cam2velo, obj_pose_c))

        # Get the IMU pose for the corresponding frame
        pose_imu_w_frame_i = poses_imu_tracking[int(frame_id)]

        # Calculate the world pose of the object
        pose_obj_w_i = np.matmul(pose_imu_w_frame_i, obj_pose_imu)
        pose_obj_w_i = np.matmul(transform_matrix, pose_obj_w_i)
        # pose_obj_w_i[:, 3] *= scale_factor

        # Calculate the approximate yaw angle of the object in the world frame
        yaw_aprox = -np.arctan2(pose_obj_w_i[1, 0], pose_obj_w_i[0, 0])

        # store center and yaw
        centers_list[index][frame_id - start_frame] = torch.from_numpy(pose_obj_w_i[:3, 3])
        yaws_list[index][frame_id - start_frame] = yaw_aprox

    max_obj_per_frame = len(track_ids)

    instances_obj = Instances(
        centers=torch.stack(centers_list, dim=0),
        yaws=torch.stack(yaws_list, dim=0),
        timestamps=torch.stack(timestamps_list, dim=0),
        valid_mask=torch.stack(timestamps_list) > 0,
    )

    return instances_obj

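def get_scene_images_tracking(
    tracking_path, sequence, selected_frames, use_depth=False, use_semantic=False, semantic_path=None, pred_normals=False
):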
    [start_frame, end_frame] = selected_frames
    # imgs = []
    img_name = []
    depth_name = []
    semantic_name = []
    normal_name = []

    left_img_path = os.path.join(os.path.join(tracking_path, "image_02"), sequence)
    right_img_path = os.path.join(os.path.join(tracking_path, "image_03"), sequence)

    left_normal_path = os.path.join(os.path.join(tracking_path, "normal_02"), sequence)
    right_normal_path = os.path.join(os.path.join(tracking_path, "normal_03"), sequence)

    if pred_normals:
        for frame_dir in [left_normal_path, right_normal_path]:
            for frame_no in range(len(os.listdir(left_normal_path))):
                if start_frame <= frame_no <= end_frame:
                    frame = sorted(os.listdir(frame_dir))[frame_no]
                    fname = os.path.join(frame_dir, frame)
                    normal_name.append(fname)

    if use_depth:
        left_depth_path = os.path.join(os.path.join(tracking_path, "completion_02"), sequence)
        right_depth_path = os.path.join(os.path.join(tracking_path, "completion_03"), sequence)

    for frame_dir in [left_img_path, right_img_path]:
        for frame_no in range(len(os.listdir(left_img_path))):
            if start_frame <= frame_no <= end_frame:
                frame = sorted(os.listdir(frame_dir))[frame_no]
                fname = os.path.join(frame_dir, frame)
                img_name.append(fname)
                # imgs.append(imageio.imread(fname))

    if use_depth:
        for frame_dir in [left_depth_path, right_depth_path]:
            for frame_no in range(len(os.listdir(left_depth_path))):
                if start_frame <= frame_no <= end_frame:
                    frame = sorted(os.listdir(frame_dir))[frame_no]
                    fname = os.path.join(frame_dir, frame)
                    depth_name.append(fname)

    if use_semantic:
        frame_dir = os.path.join(semantic_path, "train", sequence)
        for _ in range(2):
            for frame_no in range(len(os.listdir(frame_dir))):
                if start_frame <= frame_no <= end_frame:
                    frame = sorted(os.listdir(frame_dir))[frame_no]
                    fname = os.path.join(frame_dir, frame)
                    semantic_name.append(fname)

    # imgs = (np.maximum(np.minimum(np.array(imgs), 255), 0) / 255.0).astype(np.float32)
    return img_name, depth_name, semantic_name, normal_name

def get_rays_np(H, W, focal, c2w):
    """Get ray origins, directions from a pinhole camera."""
    # Numpy Version
    i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing="xy")
    dirs = np.stack([(i - W * 0.5) / focal, -(j - H * 0.5) / focal, -np.ones_like(i)], -1)
    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
    rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d))
    return rays_o, rays_d

def rotate_yaw(p, yaw):
    """Rotates p with yaw in the given coord frame with y being the relevant axis and pointing downwards

    Args:
        p: 3D points in a given frame [N_pts, N_frames, 3]/[N_pts, N_frames, N_samples, 3]
        yaw: Rotation angle

    Returns:
        p: Rotated points [N_pts, N_frames, N_samples, 3]
    """
    # p of size [batch_rays, n_obj, samples, xyz]
    if len(p.shape) < 4:
        # p = p[..., tf.newaxis, :]
        p = p.unsqueeze(-2)

    c_y = torch.cos(yaw)
    s_y = torch.sin(yaw)

    if len(c_y.shape) < 3:
        c_y = c_y.unsqueeze(-1)
        s_y = s_y.unsqueeze(-1)

    # c_y = tf.cos(yaw)[..., tf.newaxis]
    # s_y = tf.sin(yaw)[..., tf.newaxis]

    p_x = c_y * p[..., 0] - s_y * p[..., 2]
    p_y = p[..., 1]
    p_z = s_y * p[..., 0] + c_y * p[..., 2]

    # return tf.concat([p_x[..., tf.newaxis], p_y[..., tf.newaxis], p_z[..., tf.newaxis]], axis=-1)
    return torch.cat([p_x.unsqueeze(-1), p_y.unsqueeze(-1), p_z.unsqueeze(-1)], dim=-1)

def scale_frames(p, sc_factor, inverse=False):
    """Scales points given in N_frames in each dimension [xyz] for each frame or rescales for inverse==True

    Args:
        p: Points given in N_frames frames [N_points, N_frames, N_samples, 3]
        sc_factor: Scaling factor for new frame [N_points, N_frames, 3]
        inverse: Inverse scaling if true, bool

    Returns:
        p_scaled: Points given in N_frames rescaled frames [N_points, N_frames, N_samples, 3]
    """
    # Take 150% of bbox to include shadows etc.
    dim = torch.tensor([1.0, 1.0, 1.0]) * sc_factor
    # dim = tf.constant([0.1, 0.1, 0.1]) * sc_factor

    half_dim = dim / 2
    scaling_factor = (1 / (half_dim + 1e-9)).unsqueeze(-2)
    # scaling_factor = (1 / (half_dim + 1e-9))[:, :, tf.newaxis, :]

    if not inverse:
        p_scaled = scaling_factor * p
    else:
        p_scaled = (1.0 / scaling_factor) * p

    return p_scaled

def get_all_ray_3dbox_intersection(rays_rgb, obj_meta_tensor, chunk, local=False, obj_to_remove=-100):
    """Get all rays hitting an object given 3D multi-object-tracking results of a sequence

    Args:
        rays_rgb: All rays
        obj_meta_tensor: Metadata of all objects
        chunk: No. of rays processed at the same time
        local: Limit used memory if processed on a local machine with limited CPU/GPU resources
        obj_to_remove: If object should be removed from the set of rays

    Returns:
        rays_on_obj: Set of all rays hitting at least one object
        rays_to_remove: Set of all rays hitting an object, that should not be trained
        box_points_insters: Output of box_pts for the last processed batch
    """

    print("Removing object ", obj_to_remove)
    rays_on_obj = np.array([])
    rays_to_remove = np.array([])
    _batch_sz_inter = chunk if not local else 5000  # args.chunk
    _only_intersect_rays_rgb = rays_rgb[0][None]
    _n_rays = rays_rgb.shape[0]
    _n_obj = (rays_rgb.shape[1] - 3) // 2
    _n_bt = np.ceil(_n_rays / _batch_sz_inter).astype(np.int32)

    for i in range(_n_bt):
        _tf_rays_rgb = torch.from_numpy(rays_rgb[i * _batch_sz_inter : (i + 1) * _batch_sz_inter]).to(torch.float32)
        # _tf_rays_rgb = tf.cast(rays_rgb[i * _batch_sz_inter:(i + 1) * _batch_sz_inter], tf.float32)
        _n_bt_i = _tf_rays_rgb.shape[0]
        _rays_bt = [_tf_rays_rgb[:, 0, :], _tf_rays_rgb[:, 1, :]]
        _objs = torch.reshape(_tf_rays_rgb[:, 3:, :], (_n_bt_i, _n_obj, 6))
        # _objs = tf.reshape(_tf_rays_rgb[:, 3:, :], [_n_bt_i, _n_obj, 6])
        _obj_pose = _objs[..., :3]
        _obj_theta = _objs[..., 3]
        _obj_id = _objs[..., 4].to(torch.int64)
        # _obj_id = tf.cast(_objs[..., 4], tf.int32)
        _obj_meta = torch.index_select(obj_meta_tensor, 0, _obj_id.reshape(-1)).reshape(
            -1, _obj_id.shape[1], obj_meta_tensor.shape[1]
        )
        # _obj_meta = tf.gather(obj_meta_tensor, _obj_id, axis=0)
        _obj_track_id = _obj_meta[..., 0].unsqueeze(-1)
        _obj_dim = _obj_meta[..., 1:4]

        box_points_insters = box_pts(_rays_bt, _obj_pose, _obj_theta, _obj_dim, one_intersec_per_ray=False)
        _mask = box_points_insters[8]
        if _mask is not None:
            if rays_on_obj.any():
                rays_on_obj = np.concatenate([rays_on_obj, np.array(i * _batch_sz_inter + (_mask[:, 0]).cpu().numpy())])
            else:
                rays_on_obj = np.array(i * _batch_sz_inter + _mask[:, 0].cpu().numpy())
            if obj_to_remove is not None:
                _hit_id = _obj_track_id[_mask]
                import pdb

                # _hit_id = tf.gather_nd(_obj_track_id, _mask)
                # bool_remove = tf.equal(_hit_id, obj_to_remove)
                bool_remove = np.equal(_hit_id, obj_to_remove)
                if any(bool_remove):
                    # _remove_mask = tf.gather_nd(_mask, tf.where(bool_remove))
                    _remove_mask = np.array(_mask[:, 0])[np.where(np.equal(_hit_id, obj_to_remove))[0]]
                    if rays_to_remove.any():
                        rays_to_remove = np.concatenate([rays_to_remove, np.array(i * _batch_sz_inter + _remove_mask)])
                    else:
                        rays_to_remove = np.array(i * _batch_sz_inter + _remove_mask)

    return rays_on_obj, rays_to_remove, box_points_insters

@dataclass
class MarsKittiDataParserConfig(DataParserConfig):
    """neural scene graph dataset parser config"""

    _target: Type = field(default_factory=lambda: MarsKittiDataparser)
    """target class to instantiate"""
    data: Path = Path("data/kitti/training/image_02/0005")
    """Directory specifying location of data."""
    sequence_id: str = "0006"
    """Sequence ID of the KITTI data."""
    scale_factor: float = 1
    """How much to scale the camera origins by."""
    scene_scale: float = 1.0
    """How much to scale the region of interest by."""
    alpha_color: str = "white"
    """alpha color of background"""
    first_frame: int = 65
    """specifies the beginning of a sequence if not the complete scene is taken as Input"""
    last_frame: int = 120
    """specifies the end of a sequence"""
    box_scale: float = 1.5
    """Maximum scale for bboxes to include shadows"""
    near_plane: float = 0.5
    """specifies the distance from the last pose to the near plane"""
    far_plane: float = 150.0
    """specifies the distance from the last pose to the far plane"""
    netchunk: int = 1024 * 64
    """number of pts sent through network in parallel, decrease if running out of memory"""
    chunk: int = 1024 * 32
    """number of rays processed in parallel, decrease if running out of memory"""
    use_car_latents: bool = True
    car_object_latents_path: Optional[Path] = Path(f"pretrain/car_nerf/kitti/latent_codes{sequence_id[-2:]}.pt")
    """path of car object latent codes"""
    car_nerf_state_dict_path: Optional[Path] = Path("pretrain/car_nerf/kitti/epoch_670.ckpt")
    """path of car nerf state dicts"""
    use_depth: bool = True
    """whether the training loop contains depth"""
    split_setting: str = "reconstruction"
    use_semantic: bool = False
    """whether to use semantic information"""
    semantic_path: Optional[Path] = Path("")
    """path of semantic inputs"""
    semantic_mask_classes: List[str] = field(default_factory=lambda: [])
    """semantic classes that do not generate gradient to the background model"""
    storage_config: StorageConfig = field(default_factory=StorageConfig)
    """machine specific configs."""
    pred_normals: bool = False
    """whether to use normals information"""

class MarsKittiDataparser(DataParser):
    """neural scene graph kitti Dataset"""

    config: MarsKittiDataParserConfig

    def __init__(self, config: MarsKittiDataParserConfig):
        super().__init__(config=config)
        self.data: Path = config.data
        self.scale_factor: float = config.scale_factor
        self.alpha_color = config.alpha_color
        self.selected_frames = [config.first_frame, config.last_frame]
        self.use_time = False
        self.remove = -1
        self.near = config.near_plane
        self.far = config.far_plane
        self.netchunk = config.netchunk
        self.chunk = config.chunk
        self.use_semantic = config.use_semantic
        self.semantic_path: Path = config.storage_config.datapath_dict["KITTI-MOT-home"] / "panoptic_maps"
        self.pred_normals = config.pred_normals

        if "KITTI-MOT-home" in self.config.storage_config.datapath_dict:
            self.data = (
                self.config.storage_config.datapath_dict["KITTI-MOT-home"]
                / "training"
                / "image_02"
                / self.config.sequence_id
            )

        if "CarNeRF-latents" in self.config.storage_config.datapath_dict:
            self.config.car_object_latents_path = (
                self.config.storage_config.datapath_dict["CarNeRF-latents"]
                / "kitti_mot"
                / "latents"
                / f"latent_codes{self.config.sequence_id[-2:]}.pt"
            )
            assert self.config.car_object_latents_path.exists()

        if "CarNeRF-pretrained-model" in self.config.storage_config.datapath_dict:
            self.config.car_nerf_state_dict_path = (
                self.config.storage_config.datapath_dict["CarNeRF-pretrained-model"] / "epoch_670.ckpt"
            )
            assert self.config.car_nerf_state_dict_path.exists()

    def _generate_dataparser_outputs(self, split="train"):
        visible_objects_ls = []
        objects_meta_ls = []
        semantic_meta = []

        kitti2vkitti = np.array(
            [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, -1.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
        )
        if self.alpha_color is not None:
            alpha_color_tensor = get_color(self.alpha_color)
        else:
            alpha_color_tensor = None

        basedir = str(self.data)
        scene_id = basedir[-4:]  # check
        kitti_scene_no = int(scene_id)
        tracking_path = basedir[:-13]  # check
        calibration_path = os.path.join(os.path.join(tracking_path, "calib"), scene_id + ".txt")
        oxts_path_tracking = os.path.join(os.path.join(tracking_path, "oxts"), scene_id + ".txt")
        tracklet_path = os.path.join(os.path.join(tracking_path, "label_02"), scene_id + ".txt")

        tracking_calibration = tracking_calib_from_txt(calibration_path)
        focal_X = tracking_calibration["P2"][0, 0]
        focal_Y = tracking_calibration["P2"][1, 1]
        poses_imu_w_tracking, _, _ = get_poses_calibration(basedir, oxts_path_tracking)  # (n_frames, 4, 4) imu pose

        tr_imu2velo = tracking_calibration["Tr_imu2velo"]
        tr_velo2imu = invert_transformation(tr_imu2velo[:3, :3], tr_imu2velo[:3, 3])
        poses_velo_w_tracking = np.matmul(poses_imu_w_tracking, tr_velo2imu)  # (n_frames, 4, 4) velodyne pose

        if self.use_semantic:
            semantics = pd.read_csv(
                os.path.join(self.semantic_path, "colors", scene_id + ".txt"),
                sep=" ",

        if self.use_semantic:
            semantics = semantics.loc[~semantics["Category"].isin(self.config.semantic_mask_classes)]
            semantic_meta = Semantics(
                colors=torch.tensor(semantics.iloc[:, 1:].values),

        # Get camera poses; camera ids: 02, 03
        for cam_i in range(2):
            transformation = np.eye(4)
            projection = tracking_calibration["P" + str(cam_i + 2)]  # rectified camera coordinate system -> image
            K_inv = np.linalg.inv(projection[:3, :3])
            R_t = projection[:3, 3]

            t_crect2c = np.matmul(K_inv, R_t)
            # t_crect2c = 1./projection[[0, 1, 2],[0, 1, 2]] * projection[:, 3]
            transformation[:3, 3] = t_crect2c
            tracking_calibration["Tr_camrect2cam0" + str(cam_i + 2)] = transformation

        sequ_frames = self.selected_frames

        cam_poses_tracking = get_camera_poses_tracking(
            poses_velo_w_tracking, tracking_calibration, sequ_frames, kitti_scene_no
        )
        # cam_poses_tracking[..., :3, 3] *= self.scale_factor

        # Orients and centers the poses
        oriented = torch.from_numpy(np.array(cam_poses_tracking).astype(np.float32))  # (n_frames, 3, 4)
        oriented, transform_matrix = camera_utils.auto_orient_and_center_poses(
        )  # oriented (n_frames, 3, 4), transform_matrix (3, 4)
        row = torch.tensor([0, 0, 0, 1], dtype=torch.float32)
        zeros = torch.zeros(oriented.shape[0], 1, 4)
        oriented = torch.cat([oriented, zeros], dim=1)
        oriented[:, -1] = row  # (n_frames, 4, 4)
        transform_matrix = torch.cat([transform_matrix, row[None, :]], dim=0)  # (4, 4)
        cam_poses_tracking = oriented.numpy()
        transform_matrix = transform_matrix.numpy()
        image_filenames, depth_name, semantic_name, normal_name = get_scene_images_tracking(

        # Get Object poses
        instances_obj = get_obj_pose_tracking(
            tracklet_path, poses_imu_w_tracking, tracking_calibration, sequ_frames, transform_matrix
        )
        # deleted visible_objects_, objects_meta_,

        # # Align Axis with vkitti axis
        poses = np.matmul(kitti2vkitti, cam_poses_tracking).astype(np.float32)
        poses[..., :3, 3] *= self.scale_factor
        instances_obj.centers[:, :, 2] *= -1
        instances_obj.centers[:, :, [0, 1, 2]] = instances_obj.centers[:, :, [0, 2, 1]]
        instances_obj.centers *= self.scale_factor
        instances_obj.width *= (
            self.config.box_scale * self.scale_factor
        )  # scale up the bbox to include shadows, etc.; scale bbox according to the scene scale factor
        instances_obj.length *= self.config.box_scale * self.scale_factor
        instances_obj.height *= self.scale_factor  # height is not multiplied by box scale

        N_obj = instances_obj.num_instances

        counts = np.arange(instances_obj.len_tracklet * 2).reshape(2, -1)
        frame_timestamps = np.concatenate(
            [np.arange(instances_obj.len_tracklet), np.arange(instances_obj.len_tracklet)]
        )
        i_test = np.array([(idx + 1) % 4 == 0 for idx in counts[0]])
        i_test = np.concatenate((i_test, i_test))
        if self.config.split_setting == "reconstruction":
            i_train = np.ones(instances_obj.len_tracklet * 2, dtype=bool)
        elif self.config.split_setting == "nvs-75":
            i_train = ~i_test
        elif self.config.split_setting == "nvs-50":
            i_train = np.array([(idx + 1) % 4 > 1 for idx in counts[0]])
            i_train = np.concatenate((i_train, i_train))
        elif self.config.split_setting == "nvs-25":
            i_train = np.array([idx % 4 == 0 for idx in counts[0]])
            i_train = np.concatenate((i_train, i_train))
        else:
            raise ValueError("No such split method")

        counts = counts.reshape(-1)
        i_train = counts[i_train]
        train_timestamps = frame_timestamps[i_train]
        i_test = counts[i_test]
        test_timestamps = frame_timestamps[i_test]

        test_load_image = imageio.imread(image_filenames[0])
        image_height, image_width = test_load_image.shape[:2]
        cx, cy = image_width / 2.0, image_height / 2.0

        if split == "train":
            indices = i_train
            timestamps = train_timestamps
        elif split == "val":
            indices = i_test
            timestamps = test_timestamps
        elif split == "test":
            indices = i_test
            timestamps = test_timestamps
        else:
            raise ValueError(f"Unknown dataparser split {split}")

        image_filenames = [image_filenames[i] for i in indices]
        depth_filenames = [depth_name[i] for i in indices] if self.config.use_depth else None
        normal_filenames = [normal_name[i] for i in indices] if self.pred_normals else None
        if self.use_semantic:
            semantic_meta.filenames = [semantic_name[i] for i in indices]
        poses = poses[indices]
        if self.config.use_car_latents:
            if not self.config.car_object_latents_path.exists():
                CONSOLE.print("[yellow]Error: latents do not exist")
            car_latents = torch.load(str(self.config.car_object_latents_path))
            track_car_latents = {}
            track_car_latents_mean = {}
            for k, idx in enumerate(car_latents["indices"]):
                if sequ_frames[0] <= idx["fid"] <= sequ_frames[1]:
                    if idx["oid"] in track_car_latents.keys():
                        track_car_latents[idx["oid"]] = torch.cat(
                            [track_car_latents[idx["oid"]], car_latents["latents"][k].unsqueeze(-1)], dim=-1
                        )
                    else:
                        track_car_latents[idx["oid"]] = car_latents["latents"][k].unsqueeze(-1)
            for k in track_car_latents.keys():
                track_car_latents_mean[k] = track_car_latents[k][..., -1]

        else:
            car_latents = None

        aabb_scale = self.config.scene_scale
        scene_box = SceneBox(
            aabb=torch.tensor(
                [[-aabb_scale, -aabb_scale, -aabb_scale], [aabb_scale, aabb_scale, aabb_scale]], dtype=torch.float32
            )
        )

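        cameras = Cameras(
            camera_to_worlds=torch.from_numpy(poses[:, :3, :4]),
            fx=focal_X,
            fy=focal_Y,
            cx=cx,
            cy=cy,
            camera_type=CameraType.PERSPECTIVE,
            height=image_height,
            width=image_width,
        )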

        dataparser_outputs = DataparserOutputs(
            image_filenames=image_filenames,
            cameras=cameras,
            alpha_color=alpha_color_tensor,
            scene_box=scene_box,
            metadata={
                "depth_filenames": depth_filenames,
                "obj_class": instances_obj.unique_category_ids,
                "scale_factor": self.scale_factor,
                "semantics": semantic_meta,
                "normal_filenames": normal_filenames,
                "instances": instances_obj,
                "frame_timestamps": timestamps,
            },
        )

        if self.config.use_car_latents:
            dataparser_outputs.metadata.update(
                {
                    "car_latents": track_car_latents_mean,
                    "car_nerf_state_dict_path": self.config.car_nerf_state_dict_path,
                }
            )

        print("finished data parsing")
        return dataparser_outputs

KittiParserSpec = DataParserSpecification(config=MarsKittiDataParserConfig())
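
For anyone trying to work out which frames land in which split: below is a minimal sketch of the `split_setting` logic from `_generate_dataparser_outputs`, rewritten in plain Python so it runs without numpy or MARS. The helper name `split_indices` is made up for illustration and is not part of MARS. It only mirrors the modulo-4 rules in the script above, and assumes the two cameras are concatenated (camera 02 first, then camera 03), as the `counts` array in the script suggests.

```python
def split_indices(n_frames, split_setting):
    """Return (train_idx, test_idx) over 2 * n_frames images (camera 02, then camera 03).

    Hypothetical helper: mirrors the modulo-4 rules of the dataparser, not part of MARS.
    """
    # Every 4th frame (frames 3, 7, 11, ...) is held out for testing.
    test_mask = [(idx + 1) % 4 == 0 for idx in range(n_frames)]
    if split_setting == "reconstruction":
        train_mask = [True] * n_frames  # train on every frame
    elif split_setting == "nvs-75":
        train_mask = [not held_out for held_out in test_mask]
    elif split_setting == "nvs-50":
        train_mask = [(idx + 1) % 4 > 1 for idx in range(n_frames)]
    elif split_setting == "nvs-25":
        train_mask = [idx % 4 == 0 for idx in range(n_frames)]
    else:
        raise ValueError("No such split method")

    # Both cameras share the same mask; camera 03 indices are offset by n_frames.
    train_idx = [i for i, keep in enumerate(train_mask * 2) if keep]
    test_idx = [i for i, keep in enumerate(test_mask * 2) if keep]
    return train_idx, test_idx


if __name__ == "__main__":
    for setting in ("reconstruction", "nvs-75", "nvs-50", "nvs-25"):
        train_idx, test_idx = split_indices(8, setting)
        print(setting, train_idx, test_idx)
```

Note that with `"reconstruction"` the held-out frames are also in the training set, so the test split there measures reconstruction quality rather than novel-view synthesis.
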
Nplace-su commented 8 months ago

Here's a refactored dataparser for the KITTI dataset. However, some minor changes should be made to plug it into the current mars release version. Hope this could help you guys with the memory issue. This script will be in some future release version, stay tuned~

from __future__ import annotations

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Type

import imageio
import numpy as np
import pandas as pd
import torch
from rich.console import Console

from mars.configs.storage_configs.base_storage_config import StorageConfig
from mars.instances.instances import Instances
from mars.utils.neural_scene_graph_helper import box_pts
from nerfstudio.cameras import camera_utils
from nerfstudio.cameras.cameras import Cameras, CameraType
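from nerfstudio.data.dataparsers.base_dataparser import (
    DataParser,
    DataParserConfig,
    DataparserOutputs,
    Semantics,
)
from nerfstudio.data.scene_box import SceneBox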
from nerfstudio.plugins.registry_dataparser import DataParserSpecification
from nerfstudio.utils.colors import get_color

CONSOLE = Console(width=120)
_sem2label = {"Misc": -1, "Car": 0, "Van": 0, "Truck": 2, "Tram": 3, "Pedestrian": 4}
camera_ls = [2, 3]

def kitti_string_to_float(str):
    return float(str.split("e")[0]) * 10 ** int(str.split("e")[1])

def get_rotation(roll, pitch, heading):
    s_heading = np.sin(heading)
    c_heading = np.cos(heading)
    rot_z = np.array([[c_heading, -s_heading, 0], [s_heading, c_heading, 0], [0, 0, 1]])

    s_pitch = np.sin(pitch)
    c_pitch = np.cos(pitch)
    rot_y = np.array([[c_pitch, 0, s_pitch], [0, 1, 0], [-s_pitch, 0, c_pitch]])

    s_roll = np.sin(roll)
    c_roll = np.cos(roll)
    rot_x = np.array([[1, 0, 0], [0, c_roll, -s_roll], [0, s_roll, c_roll]])

    rot = np.matmul(rot_z, np.matmul(rot_y, rot_x))

    return rot

def invert_transformation(rot, t):
    t = np.matmul(-rot.T, t)
    inv_translation = np.concatenate([rot.T, t[:, None]], axis=1)
    return np.concatenate([inv_translation, np.array([[0.0, 0.0, 0.0, 1.0]])])

def calib_from_txt(calibration_path):
    """
    Read the calibration files and extract the required transformation matrices and focal length.

    Args:
        calibration_path (str): The path to the directory containing the calibration files.

    Returns:
        tuple: A tuple containing the following elements:
            imu2v (np.array): 4x4 transformation matrix from IMU to Velodyne coordinates.
            v2c (np.array): 4x4 transformation matrix from Velodyne to left camera coordinates.
            c2leftRGB (np.array): 4x4 transformation matrix from left camera to rectified left camera coordinates.
            c2rightRGB (np.array): 4x4 transformation matrix from right camera to rectified right camera coordinates.
            focal (float): Focal length of the left camera.
    """
    c2c = []

    # Read and parse the camera-to-camera calibration file
    with open(os.path.join(calibration_path, "calib_cam_to_cam.txt"), "r") as f:
        cam_to_cam_str = f.read()
    [left_cam, right_cam] = cam_to_cam_str.split("S_02: ")[1].split("S_03: ")
    cam_to_cam_ls = [left_cam, right_cam]

    # Extract the transformation matrices for left and right cameras
    for i, cam_str in enumerate(cam_to_cam_ls):
        r_str, t_str = cam_str.split("R_0" + str(i + 2) + ": ")[1].split("\nT_0" + str(i + 2) + ": ")
        t_str = t_str.split("\n")[0]
        R = np.array([kitti_string_to_float(r) for r in r_str.split(" ")])
        R = np.reshape(R, [3, 3])
        t = np.array([kitti_string_to_float(t) for t in t_str.split(" ")])
        Tr = np.concatenate([np.concatenate([R, t[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]])

        t_str_rect, s_rect_part = cam_str.split("\nT_0" + str(i + 2) + ": ")[1].split("\nS_rect_0" + str(i + 2) + ": ")
        s_rect_str, r_rect_part = s_rect_part.split("\nR_rect_0" + str(i + 2) + ": ")
        r_rect_str = r_rect_part.split("\nP_rect_0" + str(i + 2) + ": ")[0]
        R_rect = np.array([kitti_string_to_float(r) for r in r_rect_str.split(" ")])
        R_rect = np.reshape(R_rect, [3, 3])
        t_rect = np.array([kitti_string_to_float(t) for t in t_str_rect.split(" ")])
        Tr_rect = np.concatenate(
            [np.concatenate([R_rect, t_rect[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]]
        )

        c2c.append(Tr_rect)


    c2leftRGB = c2c[0]
    c2rightRGB = c2c[1]

    # Read and parse the Velodyne-to-camera calibration file
    with open(os.path.join(calibration_path, "calib_velo_to_cam.txt"), "r") as f:
        velo_to_cam_str = f.read()
    r_str, t_str = velo_to_cam_str.split("R: ")[1].split("\nT: ")
    t_str = t_str.split("\n")[0]
    R = np.array([kitti_string_to_float(r) for r in r_str.split(" ")])
    R = np.reshape(R, [3, 3])
    t = np.array([kitti_string_to_float(r) for r in t_str.split(" ")])
    v2c = np.concatenate([np.concatenate([R, t[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]])

    # Read and parse the IMU-to-Velodyne calibration file
    with open(os.path.join(calibration_path, "calib_imu_to_velo.txt"), "r") as f:
        imu_to_velo_str = f.read()
    r_str, t_str = imu_to_velo_str.split("R: ")[1].split("\nT: ")
    R = np.array([kitti_string_to_float(r) for r in r_str.split(" ")])
    R = np.reshape(R, [3, 3])
    t = np.array([kitti_string_to_float(r) for r in t_str.split(" ")])
    imu2v = np.concatenate([np.concatenate([R, t[:, None]], axis=1), np.array([0.0, 0.0, 0.0, 1.0])[None, :]])

    # Extract the focal length of the left camera
    focal = kitti_string_to_float(left_cam.split("P_rect_02: ")[1].split()[0])

    return imu2v, v2c, c2leftRGB, c2rightRGB, focal

def tracking_calib_from_txt(calibration_path):
    """
    Extract tracking calibration information from a KITTI tracking calibration file.

    This function reads a KITTI tracking calibration file and extracts the relevant
    calibration information, including projection matrices and transformation matrices
    for camera, LiDAR, and IMU coordinate systems.

    Args:
        calibration_path (str): Path to the KITTI tracking calibration file.

    Returns:
        dict: A dictionary containing the following calibration information:
            P0, P1, P2, P3 (np.array): 3x4 projection matrices for the cameras.
            Tr_cam2camrect (np.array): 4x4 transformation matrix from camera to rectified camera coordinates.
            Tr_velo2cam (np.array): 4x4 transformation matrix from LiDAR to camera coordinates.
            Tr_imu2velo (np.array): 4x4 transformation matrix from IMU to LiDAR coordinates.
    """
    # Read the calibration file
    with open(calibration_path) as f:
        calib_str = f.read().splitlines()

    # Process the calibration data
    calibs = []
    for calibration in calib_str:
        calibs.append(np.array([kitti_string_to_float(val) for val in calibration.split()[1:]]))

    # Extract the projection matrices
    P0 = np.reshape(calibs[0], [3, 4])
    P1 = np.reshape(calibs[1], [3, 4])
    P2 = np.reshape(calibs[2], [3, 4])
    P3 = np.reshape(calibs[3], [3, 4])

    # Extract the transformation matrix for camera to rectified camera coordinates
    Tr_cam2camrect = np.eye(4)
    R_rect = np.reshape(calibs[4], [3, 3])
    Tr_cam2camrect[:3, :3] = R_rect

    # Extract the transformation matrices for LiDAR to camera and IMU to LiDAR coordinates
    Tr_velo2cam = np.concatenate([np.reshape(calibs[5], [3, 4]), np.array([[0.0, 0.0, 0.0, 1.0]])], axis=0)
    Tr_imu2velo = np.concatenate([np.reshape(calibs[6], [3, 4]), np.array([[0.0, 0.0, 0.0, 1.0]])], axis=0)

    return {
        "P0": P0,
        "P1": P1,
        "P2": P2,
        "P3": P3,
        "Tr_cam2camrect": Tr_cam2camrect,
        "Tr_velo2cam": Tr_velo2cam,
        "Tr_imu2velo": Tr_imu2velo,
    }

def get_poses_calibration(basedir, oxts_path_tracking=None, selected_frames=None):
    """
    Extract poses and calibration information from the KITTI dataset.

    This function processes the OXTS data (GPS/IMU) and extracts the
    pose information (translation and rotation) for each frame. It also
    retrieves the calibration information (transformation matrices and focal length)
    required for further processing.

    Args:
        basedir (str): The base directory containing the KITTI dataset.
        oxts_path_tracking (str, optional): Path to the OXTS data file for tracking sequences.
            If not provided, the function will look for OXTS data in the basedir.
        selected_frames (list, optional): A list of frame indices to process.
            If not provided, all frames in the dataset will be processed.

    Returns:
        tuple: A tuple containing the following elements:
            poses (np.array): An array of 4x4 pose matrices representing the vehicle's
                position and orientation for each frame (IMU pose).
            calibrations (tuple or None): The tuple returned by calib_from_txt
                (transformation matrices and focal length), or None when
                oxts_path_tracking is given.
            focal (float or None): The focal length of the left camera, or None when
                oxts_path_tracking is given.
    """

    def oxts_to_pose(oxts):
        """
        OXTS (Oxford Technical Solutions) data typically refers to the data generated by an Inertial
        and GPS Navigation System (INS/GPS) that is used to provide accurate position, orientation,
        and velocity information for a moving platform, such as a vehicle. In the context of the
        KITTI dataset, OXTS data is used to provide the ground truth for the vehicle's trajectory
        and 6 degrees of freedom (6-DoF) motion, which is essential for evaluating and benchmarking
        various computer vision and robotics algorithms, such as visual odometry, SLAM, and object
        detection.

        The OXTS data contains several important measurements:

        1. Latitude, longitude, and altitude: These are the global coordinates of the moving platform.
        2. Roll, pitch, and yaw (heading): These are the orientation angles of the platform, usually
           given in Euler angles.
        3. Velocity (north, east, and down): These are the linear velocities of the platform in the
           local navigation frame.
        4. Accelerations (ax, ay, az): These are the linear accelerations in the platform's body frame.
        5. Angular rates (wx, wy, wz): These are the angular rates (also known as angular velocities)
           of the platform in its body frame.

        In the KITTI dataset, the OXTS data is stored as plain text files with each line corresponding
        to a timestamp. Each line in the file contains the aforementioned measurements, which are used
        to compute the ground truth trajectory and 6-DoF motion of the vehicle. This information can
        be further used for calibration, data synchronization, and performance evaluation of various
        algorithms.
        """
        poses = []

        def latlon_to_mercator(lat, lon, s):
            """
            Converts latitude and longitude coordinates to Mercator coordinates (x, y) using the given scale factor.

            The Mercator projection is a widely used cylindrical map projection that represents the Earth's surface
            as a flat, rectangular grid, distorting the size of geographical features in higher latitudes.
            This function uses the scale factor 's' to control the amount of distortion in the projection.

            Args:
                lat (float): Latitude in degrees, range: -90 to 90.
                lon (float): Longitude in degrees, range: -180 to 180.
                s (float): Scale factor, typically the cosine of the reference latitude.

            Returns:
                list: A list containing the Mercator coordinates [x, y] in meters.
            """
            r = 6378137.0  # the Earth's equatorial radius in meters
            x = s * r * ((np.pi * lon) / 180)
            y = s * r * np.log(np.tan((np.pi * (90 + lat)) / 360))
            return [x, y]

        # Compute the initial scale and pose based on the selected frames
        if selected_frames is None:
            lat0 = oxts[0][0]
            scale = np.cos(lat0 * np.pi / 180)
            pose_0_inv = None
        else:
            oxts0 = oxts[selected_frames[0][0]]
            lat0 = oxts0[0]
            scale = np.cos(lat0 * np.pi / 180)

            pose_i = np.eye(4)

            [x, y] = latlon_to_mercator(oxts0[0], oxts0[1], scale)
            z = oxts0[2]
            translation = np.array([x, y, z])
            rotation = get_rotation(oxts0[3], oxts0[4], oxts0[5])
            pose_i[:3, :] = np.concatenate([rotation, translation[:, None]], axis=1)
            pose_0_inv = invert_transformation(pose_i[:3, :3], pose_i[:3, 3])

        # Iterate through the OXTS data and compute the corresponding pose matrices
        for oxts_val in oxts:
            pose_i = np.zeros([4, 4])
            pose_i[3, 3] = 1

            [x, y] = latlon_to_mercator(oxts_val[0], oxts_val[1], scale)
            z = oxts_val[2]
            translation = np.array([x, y, z])

            roll = oxts_val[3]
            pitch = oxts_val[4]
            heading = oxts_val[5]
            rotation = get_rotation(roll, pitch, heading)  # (3,3)

            pose_i[:3, :] = np.concatenate([rotation, translation[:, None]], axis=1)  # (4, 4)
            if pose_0_inv is None:
                pose_0_inv = invert_transformation(pose_i[:3, :3], pose_i[:3, 3])

            pose_i = np.matmul(pose_0_inv, pose_i)
            poses.append(pose_i)

        return np.array(poses)

    # If there is no tracking path specified, use the default path
    if oxts_path_tracking is None:
        oxts_path = os.path.join(basedir, "oxts/data")
        oxts = np.array([np.loadtxt(os.path.join(oxts_path, file)) for file in sorted(os.listdir(oxts_path))])
        calibration_path = os.path.dirname(basedir)

        calibrations = calib_from_txt(calibration_path)

        focal = calibrations[4]

        poses = oxts_to_pose(oxts)

    # If a tracking path is specified, use it to load OXTS data and compute the poses
    else:
        oxts_tracking = np.loadtxt(oxts_path_tracking)
        poses = oxts_to_pose(oxts_tracking)  # (n_frames, 4, 4)
        calibrations = None
        focal = None
        # Set velodyne close to z = 0
        # poses[:, 2, 3] -= 0.8

    # Return the poses, calibrations, and focal length
    return poses, calibrations, focal

def get_camera_poses_tracking(poses_velo_w_tracking, tracking_calibration, selected_frames, scene_no=None):
    exp = False
    camera_poses = []

    opengl2kitti = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])

    start_frame = selected_frames[0]
    end_frame = selected_frames[1]

    # Debug Camera offset
    if scene_no == 2:
        yaw = np.deg2rad(0.7)  ## Affects camera rig roll: High --> counterclockwise
        pitch = np.deg2rad(-0.5)  ## Affects camera rig yaw: High --> Turn Right
        # pitch = np.deg2rad(-0.97)
        roll = np.deg2rad(0.9)  ## Affects camera rig pitch: High -->  up
        # roll = np.deg2rad(1.2)
    elif scene_no == 1:
        if exp:
            yaw = np.deg2rad(0.3)  ## Affects camera rig roll: High --> counterclockwise
            pitch = np.deg2rad(-0.6)  ## Affects camera rig yaw: High --> Turn Right
            # pitch = np.deg2rad(-0.97)
            roll = np.deg2rad(0.75)  ## Affects camera rig pitch: High -->  up
            # roll = np.deg2rad(1.2)
        else:
            yaw = np.deg2rad(0.5)  ## Affects camera rig roll: High --> counterclockwise
            pitch = np.deg2rad(-0.5)  ## Affects camera rig yaw: High --> Turn Right
            roll = np.deg2rad(0.75)  ## Affects camera rig pitch: High -->  up
    else:
        yaw = np.deg2rad(0.05)
        pitch = np.deg2rad(-0.75)
        # pitch = np.deg2rad(-0.97)
        roll = np.deg2rad(1.05)
        # roll = np.deg2rad(1.2)

    cam_debug = np.eye(4)
    cam_debug[:3, :3] = get_rotation(roll, pitch, yaw)

    Tr_cam2camrect = tracking_calibration["Tr_cam2camrect"]
    Tr_cam2camrect = np.matmul(Tr_cam2camrect, cam_debug)
    Tr_camrect2cam = invert_transformation(Tr_cam2camrect[:3, :3], Tr_cam2camrect[:3, 3])
    Tr_velo2cam = tracking_calibration["Tr_velo2cam"]
    Tr_cam2velo = invert_transformation(Tr_velo2cam[:3, :3], Tr_velo2cam[:3, 3])

    camera_poses_imu = []
    for cam in camera_ls:
        Tr_camrect2cam_i = tracking_calibration["Tr_camrect2cam0" + str(cam)]
        Tr_cam_i2camrect = invert_transformation(Tr_camrect2cam_i[:3, :3], Tr_camrect2cam_i[:3, 3])
        # transform camera axis from kitti to opengl for nerf:
        cam_i_camrect = np.matmul(Tr_cam_i2camrect, opengl2kitti)
        cam_i_cam0 = np.matmul(Tr_camrect2cam, cam_i_camrect)
        cam_i_velo = np.matmul(Tr_cam2velo, cam_i_cam0)

        cam_i_w = np.matmul(poses_velo_w_tracking, cam_i_velo)
        camera_poses_imu.append(cam_i_w)

    for i, cam in enumerate(camera_ls):
        for frame_no in range(start_frame, end_frame + 1):
            camera_poses.append(camera_poses_imu[i][frame_no])

    return np.array(camera_poses)

def get_obj_pose_tracking(tracklet_path, poses_imu_tracking, calibrations, selected_frames, transform_matrix):
    """
    Extracts object pose information from the KITTI motion tracking dataset for the specified frames.

    Parameters
    ----------
    tracklet_path : str
        Path to the text file containing tracklet information. A tracklet is a small sequence of
        object positions and orientations over time, often used in the context of object tracking
        and motion estimation in computer vision. In a dataset, a tracklet usually represents a
        single object's pose information across multiple consecutive frames. This information
        includes the object's position, orientation (usually as rotation around the vertical axis,
        i.e., yaw angle), and other attributes like object type, dimensions, etc. In the KITTI
        dataset, tracklets are used to store and provide ground truth information about dynamic
        objects in the scene, such as cars, pedestrians, and cyclists.

    poses_imu_tracking : list of numpy arrays
        A list of 4x4 transformation matrices representing the poses (positions and orientations)
        of the ego vehicle (the main vehicle equipped with sensors) in Inertial Measurement Unit
        (IMU) coordinates at different time instances (frames). Each matrix in the list corresponds
        to a single frame in the dataset.

    calibrations : dict
        Dictionary containing calibration information:
            - "Tr_velo2cam": 4x4 transformation matrix from Velodyne coordinates to camera coordinates.
            - "Tr_imu2velo": 4x4 transformation matrix from IMU coordinates to Velodyne coordinates.

    selected_frames : list of int
        List of two integers specifying the start and end frames to process.

    transform_matrix : numpy array
        4x4 transformation applied to the object world poses (the matrix returned when the
        camera poses are oriented and centered).

    Returns
    -------
    Object `Instances` instance.
    """

    # Helper function to generate a rotation matrix around the y-axis
    def roty_matrix(roty):
        c = np.cos(roty)
        s = np.sin(roty)
        return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])

    # Extract calibration data
    velo2cam = calibrations["Tr_velo2cam"]
    imu2velo = calibrations["Tr_imu2velo"]
    cam2velo = invert_transformation(velo2cam[:3, :3], velo2cam[:3, 3])
    velo2imu = invert_transformation(imu2velo[:3, :3], imu2velo[:3, 3])

    start_frame = selected_frames[0]
    end_frame = selected_frames[1]

    # Read tracklets from file
    with open(tracklet_path) as f:
        tracklets_str = f.read().splitlines()

    track_ids, category_ids, lengths, heights, widths = [], [], [], [], []
    track2index = {}
    centers_list = []
    yaws_list = []
    timestamps_list = []

    # Extract metadata for all objects in the scene
    for tracklet in tracklets_str:
        tracklet = tracklet.split()
        if float(tracklet[1]) < 0:
            continue
        frame_id = int(tracklet[0])
        if frame_id < start_frame or frame_id > end_frame:
            continue

        track_id = int(tracklet[1])

        if tracklet[2] not in _sem2label:
            continue
        type = _sem2label[tracklet[2]]

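        if track_id not in track2index:  # extract metadata
            height = float(tracklet[10])
            width = float(tracklet[11])
            length = float(tracklet[12])
            track2index[track_id] = len(track_ids)
            track_ids.append(track_id)
            category_ids.append(type)
            lengths.append(length)
            heights.append(height)
            widths.append(width)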

            # create empty tensors for centers and yaws
            centers_list.append(torch.zeros((end_frame - start_frame + 1, 3), dtype=torch.float64))
            yaws_list.append(torch.zeros((end_frame - start_frame + 1, 1), dtype=torch.float64))
            timestamps_list.append(-1.0 * torch.ones((end_frame - start_frame + 1, 1), dtype=torch.float64))

        # extract object track
        index = track2index[track_id]
        timestamps_list[index][frame_id - start_frame] = float(frame_id - start_frame)

        # Initialize a 4x4 identity matrix for object pose in camera coordinates
        pose = np.array(
            [float(tracklet[13]), float(tracklet[14]), float(tracklet[15]), float(tracklet[16])]
        )  # x,y,z,yaw
        obj_pose_c = np.eye(4)
        obj_pose_c[:3, 3] = pose[:3]
        roty = pose[3]
        obj_pose_c[:3, :3] = roty_matrix(roty)

        # Transform object pose from camera coordinates to IMU coordinates
        obj_pose_imu = np.matmul(velo2imu, np.matmul(cam2velo, obj_pose_c))

        # Get the IMU pose for the corresponding frame
        pose_imu_w_frame_i = poses_imu_tracking[int(frame_id)]

        # Calculate the world pose of the object
        pose_obj_w_i = np.matmul(pose_imu_w_frame_i, obj_pose_imu)
        pose_obj_w_i = np.matmul(transform_matrix, pose_obj_w_i)
        # pose_obj_w_i[:, 3] *= scale_factor

        # Calculate the approximate yaw angle of the object in the world frame
        yaw_aprox = -np.arctan2(pose_obj_w_i[1, 0], pose_obj_w_i[0, 0])

        # store center and yaw
        centers_list[index][frame_id - start_frame] = torch.from_numpy(pose_obj_w_i[:3, 3])
        yaws_list[index][frame_id - start_frame] = yaw_aprox

    max_obj_per_frame = len(track_ids)

    instances_obj = Instances(
        centers=torch.stack(centers_list, dim=0),
        yaws=torch.stack(yaws_list, dim=0),
        timestamps=torch.stack(timestamps_list, dim=0),
        valid_mask=torch.stack(timestamps_list) > 0,
    )
    return instances_obj

def get_scene_images_tracking(
    [start_frame, end_frame] = selected_frames
    # imgs = []
    img_name = []
    depth_name = []
    semantic_name = []
    normal_name = []

    left_img_path = os.path.join(os.path.join(tracking_path, "image_02"), sequence)
    right_img_path = os.path.join(os.path.join(tracking_path, "image_03"), sequence)

    left_normal_path = os.path.join(os.path.join(tracking_path, "normal_02"), sequence)
    right_normal_path = os.path.join(os.path.join(tracking_path, "normal_03"), sequence)

    if pred_normals:
        for frame_dir in [left_normal_path, right_normal_path]:
            for frame_no in range(len(os.listdir(left_normal_path))):
                if start_frame <= frame_no <= end_frame:
                    frame = sorted(os.listdir(frame_dir))[frame_no]
                    fname = os.path.join(frame_dir, frame)
                    normal_name.append(fname)

    if use_depth:
        left_depth_path = os.path.join(os.path.join(tracking_path, "completion_02"), sequence)
        right_depth_path = os.path.join(os.path.join(tracking_path, "completion_03"), sequence)

    for frame_dir in [left_img_path, right_img_path]:
        for frame_no in range(len(os.listdir(left_img_path))):
            if start_frame <= frame_no <= end_frame:
                frame = sorted(os.listdir(frame_dir))[frame_no]
                fname = os.path.join(frame_dir, frame)
                # imgs.append(imageio.imread(fname))
                img_name.append(fname)

    if use_depth:
        for frame_dir in [left_depth_path, right_depth_path]:
            for frame_no in range(len(os.listdir(left_depth_path))):
                if start_frame <= frame_no <= end_frame:
                    frame = sorted(os.listdir(frame_dir))[frame_no]
                    fname = os.path.join(frame_dir, frame)
                    depth_name.append(fname)

    if use_semantic:
        frame_dir = os.path.join(semantic_path, "train", sequence)
        for _ in range(2):
            for frame_no in range(len(os.listdir(frame_dir))):
                if start_frame <= frame_no <= end_frame:
                    frame = sorted(os.listdir(frame_dir))[frame_no]
                    fname = os.path.join(frame_dir, frame)
                    semantic_name.append(fname)

    # imgs = (np.maximum(np.minimum(np.array(imgs), 255), 0) / 255.0).astype(np.float32)
    return img_name, depth_name, semantic_name, normal_name

def get_rays_np(H, W, focal, c2w):
    """Get ray origins, directions from a pinhole camera."""
    # Numpy Version
    i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing="xy")
    dirs = np.stack([(i - W * 0.5) / focal, -(j - H * 0.5) / focal, -np.ones_like(i)], -1)
    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
    rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d))
    return rays_o, rays_d

def rotate_yaw(p, yaw):
    """Rotates p with yaw in the given coord frame with y being the relevant axis and pointing downwards

    Args:
        p: 3D points in a given frame [N_pts, N_frames, 3]/[N_pts, N_frames, N_samples, 3]
        yaw: Rotation angle

    Returns:
        p: Rotated points [N_pts, N_frames, N_samples, 3]
    """
    # p of size [batch_rays, n_obj, samples, xyz]
    if len(p.shape) < 4:
        # p = p[..., tf.newaxis, :]
        p = p.unsqueeze(-2)

    c_y = torch.cos(yaw)
    s_y = torch.sin(yaw)

    if len(c_y.shape) < 3:
        c_y = c_y.unsqueeze(-1)
        s_y = s_y.unsqueeze(-1)

    # c_y = tf.cos(yaw)[..., tf.newaxis]
    # s_y = tf.sin(yaw)[..., tf.newaxis]

    p_x = c_y * p[..., 0] - s_y * p[..., 2]
    p_y = p[..., 1]
    p_z = s_y * p[..., 0] + c_y * p[..., 2]

    # return tf.concat([p_x[..., tf.newaxis], p_y[..., tf.newaxis], p_z[..., tf.newaxis]], axis=-1)
    return torch.cat([p_x.unsqueeze(-1), p_y.unsqueeze(-1), p_z.unsqueeze(-1)], dim=-1)

def scale_frames(p, sc_factor, inverse=False):
    """Scales points given in N_frames in each dimension [xyz] for each frame or rescales for inverse==True

    Args:
        p: Points given in N_frames frames [N_points, N_frames, N_samples, 3]
        sc_factor: Scaling factor for new frame [N_points, N_frames, 3]
        inverse: Inverse scaling if true, bool

    Returns:
        p_scaled: Points given in N_frames rescaled frames [N_points, N_frames, N_samples, 3]
    """
    # Take 150% of bbox to include shadows etc.
    dim = torch.tensor([1.0, 1.0, 1.0]) * sc_factor
    # dim = tf.constant([0.1, 0.1, 0.1]) * sc_factor

    half_dim = dim / 2
    scaling_factor = (1 / (half_dim + 1e-9)).unsqueeze(-2)
    # scaling_factor = (1 / (half_dim + 1e-9))[:, :, tf.newaxis, :]

    if not inverse:
        p_scaled = scaling_factor * p
    else:
        p_scaled = (1.0 / scaling_factor) * p

    return p_scaled

def get_all_ray_3dbox_intersection(rays_rgb, obj_meta_tensor, chunk, local=False, obj_to_remove=-100):
    """Get all rays hitting an object given 3D multi-object-tracking results of a sequence

    Args:
        rays_rgb: All rays
        obj_meta_tensor: Metadata of all objects
        chunk: No. of rays processed at the same time
        local: Limit used memory if processed on a local machine with limited CPU/GPU resources
        obj_to_remove: If object should be removed from the set of rays

    Returns:
        rays_on_obj: Set of all rays hitting at least one object
        rays_to_remove: Set of all rays hitting an object, that should not be trained
    """
    print("Removing object ", obj_to_remove)
    rays_on_obj = np.array([])
    rays_to_remove = np.array([])
    _batch_sz_inter = chunk if not local else 5000  # args.chunk
    _only_intersect_rays_rgb = rays_rgb[0][None]
    _n_rays = rays_rgb.shape[0]
    _n_obj = (rays_rgb.shape[1] - 3) // 2
    _n_bt = np.ceil(_n_rays / _batch_sz_inter).astype(np.int32)

    for i in range(_n_bt):
        _tf_rays_rgb = torch.from_numpy(rays_rgb[i * _batch_sz_inter : (i + 1) * _batch_sz_inter]).to(torch.float32)
        # _tf_rays_rgb = tf.cast(rays_rgb[i * _batch_sz_inter:(i + 1) * _batch_sz_inter], tf.float32)
        _n_bt_i = _tf_rays_rgb.shape[0]
        _rays_bt = [_tf_rays_rgb[:, 0, :], _tf_rays_rgb[:, 1, :]]
        _objs = torch.reshape(_tf_rays_rgb[:, 3:, :], (_n_bt_i, _n_obj, 6))
        # _objs = tf.reshape(_tf_rays_rgb[:, 3:, :], [_n_bt_i, _n_obj, 6])
        _obj_pose = _objs[..., :3]
        _obj_theta = _objs[..., 3]
        _obj_id = _objs[..., 4].to(torch.int64)
        # _obj_id = tf.cast(_objs[..., 4], tf.int32)
        _obj_meta = torch.index_select(obj_meta_tensor, 0, _obj_id.reshape(-1)).reshape(
            -1, _obj_id.shape[1], obj_meta_tensor.shape[1]
        )
        # _obj_meta = tf.gather(obj_meta_tensor, _obj_id, axis=0)
        _obj_track_id = _obj_meta[..., 0].unsqueeze(-1)
        _obj_dim = _obj_meta[..., 1:4]

        box_points_insters = box_pts(_rays_bt, _obj_pose, _obj_theta, _obj_dim, one_intersec_per_ray=False)
        _mask = box_points_insters[8]
        if _mask is not None:
            if rays_on_obj.any():
                rays_on_obj = np.concatenate([rays_on_obj, np.array(i * _batch_sz_inter + (_mask[:, 0]).cpu().numpy())])
            else:
                rays_on_obj = np.array(i * _batch_sz_inter + _mask[:, 0].cpu().numpy())
            if obj_to_remove is not None:
                _hit_id = _obj_track_id[_mask]

                # _hit_id = tf.gather_nd(_obj_track_id, _mask)
                # bool_remove = tf.equal(_hit_id, obj_to_remove)
                bool_remove = np.equal(_hit_id, obj_to_remove)
                if any(bool_remove):
                    # _remove_mask = tf.gather_nd(_mask, tf.where(bool_remove))
                    _remove_mask = np.array(_mask[:, 0])[np.where(np.equal(_hit_id, obj_to_remove))[0]]
                    if rays_to_remove.any():
                        rays_to_remove = np.concatenate([rays_to_remove, np.array(i * _batch_sz_inter + _remove_mask)])
                    else:
                        rays_to_remove = np.array(i * _batch_sz_inter + _remove_mask)

    return rays_on_obj, rays_to_remove, box_points_insters

@dataclass
class MarsKittiDataParserConfig(DataParserConfig):
    """neural scene graph dataset parser config"""

    _target: Type = field(default_factory=lambda: MarsKittiDataparser)
    """target class to instantiate"""
    data: Path = Path("data/kitti/training/image_02/0005")
    """Directory specifying location of data."""
    sequence_id: str = "0006"
    """Sequence ID of the KITTI data."""
    scale_factor: float = 1
    """How much to scale the camera origins by."""
    scene_scale: float = 1.0
    """How much to scale the region of interest by."""
    alpha_color: str = "white"
    """alpha color of background"""
    first_frame: int = 65
    """specifies the beginning of a sequence if not the complete scene is taken as Input"""
    last_frame: int = 120
    """specifies the end of a sequence"""
    box_scale: float = 1.5
    """Maximum scale for bboxes to include shadows"""
    near_plane: float = 0.5
    """specifies the distance from the last pose to the near plane"""
    far_plane: float = 150.0
    """specifies the distance from the last pose to the far plane"""
    netchunk: int = 1024 * 64
    """number of pts sent through network in parallel, decrease if running out of memory"""
    chunk: int = 1024 * 32
    """number of rays processed in parallel, decrease if running out of memory"""
    use_car_latents: bool = True
    car_object_latents_path: Optional[Path] = Path(f"pretrain/car_nerf/kitti/latent_codes{sequence_id[-2:]}.pt")
    """path of car object latent codes"""
    car_nerf_state_dict_path: Optional[Path] = Path("pretrain/car_nerf/kitti/epoch_670.ckpt")
    """path of car nerf state dicts"""
    use_depth: bool = True
    """whether the training loop contains depth"""
    split_setting: str = "reconstruction"
    use_semantic: bool = False
    """whether to use semantic information"""
    semantic_path: Optional[Path] = Path("")
    """path of semantic inputs"""
    semantic_mask_classes: List[str] = field(default_factory=lambda: [])
    """semantic classes that do not generate gradient to the background model"""
    storage_config: StorageConfig = StorageConfig()
    """machine specific configs."""
    pred_normals: bool = False
    """whether to use normals information"""

class MarsKittiDataparser(DataParser):
    """neural scene graph kitti Dataset"""

    config: MarsKittiDataParserConfig

    def __init__(self, config: MarsKittiDataParserConfig):
        super().__init__(config=config)
        self.data: Path = config.data
        self.scale_factor: float = config.scale_factor
        self.alpha_color = config.alpha_color
        self.selected_frames = [config.first_frame, config.last_frame]
        self.use_time = False
        self.remove = -1
        self.near = config.near_plane
        self.far = config.far_plane
        self.netchunk = config.netchunk
        self.chunk = config.chunk
        self.use_semantic = config.use_semantic
        self.semantic_path: Path = config.storage_config.datapath_dict["KITTI-MOT-home"] / "panoptic_maps"
        self.pred_normals = config.pred_normals

        if "KITTI-MOT-home" in self.config.storage_config.datapath_dict:
            self.data = (
                self.config.storage_config.datapath_dict["KITTI-MOT-home"]
                / "training"
                / "image_02"
                / self.config.sequence_id
            )

        if "CarNeRF-latents" in self.config.storage_config.datapath_dict:
            self.config.car_object_latents_path = (
                self.config.storage_config.datapath_dict["CarNeRF-latents"]
                / "kitti_mot"
                / "latents"
                / f"latent_codes{self.config.sequence_id[-2:]}.pt"
            )
            assert self.config.car_object_latents_path.exists()

        if "CarNeRF-pretrained-model" in self.config.storage_config.datapath_dict:
            self.config.car_nerf_state_dict_path = (
                self.config.storage_config.datapath_dict["CarNeRF-pretrained-model"] / "epoch_670.ckpt"
            )
            assert self.config.car_nerf_state_dict_path.exists()

    def _generate_dataparser_outputs(self, split="train"):
        visible_objects_ls = []
        objects_meta_ls = []
        semantic_meta = []

        kitti2vkitti = np.array(
            [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, -1.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
        )
        if self.alpha_color is not None:
            alpha_color_tensor = get_color(self.alpha_color)
        else:
            alpha_color_tensor = None

        basedir = str(self.data)
        scene_id = basedir[-4:]  # check
        kitti_scene_no = int(scene_id)
        tracking_path = basedir[:-13]  # check
        calibration_path = os.path.join(os.path.join(tracking_path, "calib"), scene_id + ".txt")
        oxts_path_tracking = os.path.join(os.path.join(tracking_path, "oxts"), scene_id + ".txt")
        tracklet_path = os.path.join(os.path.join(tracking_path, "label_02"), scene_id + ".txt")

        tracking_calibration = tracking_calib_from_txt(calibration_path)
        focal_X = tracking_calibration["P2"][0, 0]
        focal_Y = tracking_calibration["P2"][1, 1]
        poses_imu_w_tracking, _, _ = get_poses_calibration(basedir, oxts_path_tracking)  # (n_frames, 4, 4) imu pose

        tr_imu2velo = tracking_calibration["Tr_imu2velo"]
        tr_velo2imu = invert_transformation(tr_imu2velo[:3, :3], tr_imu2velo[:3, 3])
        poses_velo_w_tracking = np.matmul(poses_imu_w_tracking, tr_velo2imu)  # (n_frames, 4, 4) velodyne pose

        if self.use_semantic:
            semantics = pd.read_csv(
                os.path.join(self.semantic_path, "colors", scene_id + ".txt"),
                sep=" ",

        if self.use_semantic:
            semantics = semantics.loc[~semantics["Category"].isin(self.config.semantic_mask_classes)]
            semantic_meta = Semantics(
                colors=torch.tensor(semantics.iloc[:, 1:].values),

        # Get camera Poses   camare id: 02, 03
        for cam_i in range(2):
            transformation = np.eye(4)
            projection = tracking_calibration["P" + str(cam_i + 2)]  # rectified camera coordinate system -> image
            K_inv = np.linalg.inv(projection[:3, :3])
            R_t = projection[:3, 3]

            t_crect2c = np.matmul(K_inv, R_t)
            # t_crect2c = 1./projection[[0, 1, 2],[0, 1, 2]] * projection[:, 3]
            transformation[:3, 3] = t_crect2c
            tracking_calibration["Tr_camrect2cam0" + str(cam_i + 2)] = transformation

        sequ_frames = self.selected_frames

        cam_poses_tracking = get_camera_poses_tracking(
            poses_velo_w_tracking, tracking_calibration, sequ_frames, kitti_scene_no
        )
        # cam_poses_tracking[..., :3, 3] *= self.scale_factor

        # Orients and centers the poses
        oriented = torch.from_numpy(np.array(cam_poses_tracking).astype(np.float32))  # (n_frames, 3, 4)
        oriented, transform_matrix = camera_utils.auto_orient_and_center_poses(
        )  # oriented (n_frames, 3, 4), transform_matrix (3, 4)
        row = torch.tensor([0, 0, 0, 1], dtype=torch.float32)
        zeros = torch.zeros(oriented.shape[0], 1, 4)
        oriented = torch.cat([oriented, zeros], dim=1)
        oriented[:, -1] = row  # (n_frames, 4, 4)
        transform_matrix = torch.cat([transform_matrix, row[None, :]], dim=0)  # (4, 4)
        cam_poses_tracking = oriented.numpy()
        transform_matrix = transform_matrix.numpy()
        image_filenames, depth_name, semantic_name, normal_name = get_scene_images_tracking(

        # Get Object poses
        instances_obj = get_obj_pose_tracking(
            tracklet_path, poses_imu_w_tracking, tracking_calibration, sequ_frames, transform_matrix
        )
        # deleted visible_objects_, objects_meta_,

        # # Align Axis with vkitti axis
        poses = np.matmul(kitti2vkitti, cam_poses_tracking).astype(np.float32)
        poses[..., :3, 3] *= self.scale_factor
        instances_obj.centers[:, :, 2] *= -1
        instances_obj.centers[:, :, [0, 1, 2]] = instances_obj.centers[:, :, [0, 2, 1]]
        instances_obj.centers *= self.scale_factor
        instances_obj.width *= (
            self.config.box_scale * self.scale_factor
        )  # scale up the bbox to include shadows, etc.; scale bbox according to the scene scale factor
        instances_obj.length *= self.config.box_scale * self.scale_factor
        instances_obj.height *= self.scale_factor  # height is not multiplied by box scale

        N_obj = instances_obj.num_instances

        counts = np.arange(instances_obj.len_tracklet * 2).reshape(2, -1)
        frame_timestamps = np.concatenate(
            [np.arange(instances_obj.len_tracklet), np.arange(instances_obj.len_tracklet)]
        )
        i_test = np.array([(idx + 1) % 4 == 0 for idx in counts[0]])
        i_test = np.concatenate((i_test, i_test))
        if self.config.split_setting == "reconstruction":
            i_train = np.ones(instances_obj.len_tracklet * 2, dtype=bool)
        elif self.config.split_setting == "nvs-75":
            i_train = ~i_test
        elif self.config.split_setting == "nvs-50":
            i_train = np.array([(idx + 1) % 4 > 1 for idx in counts[0]])
            i_train = np.concatenate((i_train, i_train))
        elif self.config.split_setting == "nvs-25":
            i_train = np.array([idx % 4 == 0 for idx in counts[0]])
            i_train = np.concatenate((i_train, i_train))
        else:
            raise ValueError("No such split method")

        counts = counts.reshape(-1)
        i_train = counts[i_train]
        train_timestamps = frame_timestamps[i_train]
        i_test = counts[i_test]
        test_timestamps = frame_timestamps[i_test]

        test_load_image = imageio.imread(image_filenames[0])
        image_height, image_width = test_load_image.shape[:2]
        cx, cy = image_width / 2.0, image_height / 2.0

        if split == "train":
            indices = i_train
            timestamps = train_timestamps
        elif split == "val":
            indices = i_test
            timestamps = test_timestamps
        elif split == "test":
            indices = i_test
            timestamps = test_timestamps
        else:
            raise ValueError(f"Unknown dataparser split {split}")

        image_filenames = [image_filenames[i] for i in indices]
        depth_filenames = [depth_name[i] for i in indices] if self.config.use_depth else None
        normal_filenames = [normal_name[i] for i in indices] if self.pred_normals else None
        if self.use_semantic:
            semantic_meta.filenames = [semantic_name[i] for i in indices]
        poses = poses[indices]
        if self.config.use_car_latents:
            if not self.config.car_object_latents_path.exists():
                CONSOLE.print("[yellow]Error: latents not exist")
            car_latents = torch.load(str(self.config.car_object_latents_path))
            track_car_latents = {}
            track_car_latents_mean = {}
            for k, idx in enumerate(car_latents["indices"]):
                if sequ_frames[0] <= idx["fid"] <= sequ_frames[1]:
                    if idx["oid"] in track_car_latents.keys():
                        track_car_latents[idx["oid"]] = torch.cat(
                            [track_car_latents[idx["oid"]], car_latents["latents"][k].unsqueeze(-1)], dim=-1
                        )
                    else:
                        track_car_latents[idx["oid"]] = car_latents["latents"][k].unsqueeze(-1)
            for k in track_car_latents.keys():
                track_car_latents_mean[k] = track_car_latents[k][..., -1]

        else:
            car_latents = None

        aabb_scale = self.config.scene_scale
        scene_box = SceneBox(
            aabb=torch.tensor(
                [[-aabb_scale, -aabb_scale, -aabb_scale], [aabb_scale, aabb_scale, aabb_scale]], dtype=torch.float32
            )
        )

        cameras = Cameras(
            camera_to_worlds=torch.from_numpy(poses[:, :3, :4]),

        dataparser_outputs = DataparserOutputs(
                "depth_filenames": depth_filenames,
                "obj_class": instances_obj.unique_category_ids,
                "scale_factor": self.scale_factor,
                "semantics": semantic_meta,
                "normal_filenames": normal_filenames,
                "instances": instances_obj,
                "frame_timestamps": timestamps,

        if self.config.use_car_latents:
                    "car_latents": track_car_latents_mean,
                    "car_nerf_state_dict_path": self.config.car_nerf_state_dict_path,

        print("finished data parsing")
        return dataparser_outputs

KittiParserSpec = DataParserSpecification(config=MarsKittiDataParserConfig())

Nplace-su commented 8 months ago

Thank you! @wuzirui IIUC, the most important refactor here is using Object Instance instead of a dense obj nodes tensor, which is not mem efficient? But the implementation and usage of Object Instance are unknown. So... there is nothing we can really do until the release of this version. If "minor changes should be made to plug it into the current mars release version", would it fall back to the high memory usage version? Correct me if I misunderstood. Thanks a lot.

wuzirui commented 8 months ago

Hi @Nplace-su! Thank you for your interest. The previous version is not mem efficient because it creates a huge tensor of [N, H, W, C] which contains exactly the same object trajectory info in each pixel. So in this version, we instead use a unique instance trajectory module, which is just a PyTorch dataclass. During the inference time, you just need to index the needed object info from the instance traj module with the frame timestamps. We will release the code, but there's still a lot of work before we can upload them. Hope this short description helps. If there's any problem, feel free to contact me.
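
To make the description above a bit more concrete, here is a minimal sketch of the idea, not the MARS implementation. The `TrajectoryTable` class, its fields, and the toy sizes are all hypothetical; the only thing taken from the thread is the contrast between a dense per-pixel tensor and a small per-object table that is indexed by frame when needed.

```python
from dataclasses import dataclass

import torch


@dataclass
class TrajectoryTable:
    """Hypothetical compact store: one entry per object and frame."""

    centers: torch.Tensor  # [n_obj, n_frames, 3]
    yaws: torch.Tensor  # [n_obj, n_frames, 1]

    def at_frames(self, frame_idx: torch.Tensor):
        # Gather the trajectory entries for the frame each ray belongs to.
        return self.centers[:, frame_idx], self.yaws[:, frame_idx]


# Toy sizes, assumed for illustration only.
n_obj, n_frames, H, W = 4, 10, 375, 1242
table = TrajectoryTable(
    centers=torch.zeros(n_obj, n_frames, 3),
    yaws=torch.zeros(n_obj, n_frames, 1),
)

# Frame index of each ray in a small batch.
frame_idx = torch.tensor([0, 3, 3, 7])
centers, yaws = table.at_frames(frame_idx)  # [n_obj, 4, 3] and [n_obj, 4, 1]

# Element counts: compact table vs. the same info copied to every pixel.
compact_elems = table.centers.numel() + table.yaws.numel()
dense_elems = n_frames * H * W * n_obj * 4  # 4 = 3 center coords + 1 yaw
```

With these assumed sizes the dense layout holds `H * W` times as many values as the compact one, which is where the redundancy described above comes from.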

JPenshornTAP commented 8 months ago

Hi @wuzirui and @Nplace-su! Thanks for your contributions and the updated code. So basically in the current version on the mars branch/ refactor all the data is stored and loaded into ram at the same time while it is more or less generator based in the updated version? Is this correct? Also, are you actively working on this? Thanks for your help!

Nplace-su commented 8 months ago

> Hi @Nplace-su! Thank you for your interest. The previous version is not mem efficient because it creates a huge tensor of [N, H, W, C] which contains exactly the same object trajectory info in each pixel. So in this version, we instead use a unique instance trajectory module, which is just a PyTorch dataclass. During the inference time, you just need to index the needed object info from the instance traj module with the frame timestamps. We will release the code, but there's still a lot of work before we can upload them. Hope this short description helps. If there's any problem, feel free to contact me.

Thanks @wuzirui for your quick reply and the detailed information! I will try to somehow implement this optimization according to your description. And, hope the official code will be uploaded asap, XD.

wuzirui commented 8 months ago

> Hi @wuzirui and @Nplace-su! Thanks for your contributions and the updated code. So basically in the current version on the mars branch/ refactor all the data is stored and loaded into ram at the same time while it is more or less generator based in the updated version? Is this correct? Also, are you actively working on this? Thanks for your help!

Hi, the updated version will still store all the trajectory info in RAM. The difference is that the current version keeps a LOT of redundant data, while the new version removes the redundancy. Also, interpolation on time will be supported.

PolarisKyle commented 8 months ago

Hi, "obj_info" stores all the information about the objects for the whole training dataset, and it costs too much RAM. How can I split it into batches, like "torch.dataloader" does?

wuzirui commented 8 months ago

> Hi, "obj_info" stores all the information about the objects for the whole training dataset, and it costs too much RAM. How can I split it into batches, like "torch.dataloader" does?

Hi, please see the updated parser I listed in this issue. It should help solve the RAM issue.

Nplace-su commented 8 months ago

For anyone who met the same issue, you can move the repeat_interleave and permute op to every batch instead of on the whole data. I can bypass the memory peak via this method, as a workaround.

Yurri-hub commented 6 months ago

> For anyone who met the same issue, you can move the repeat_interleave and permute op to every batch instead of on the whole data. I can bypass the memory peak via this method, as a workaround.

Hi~ I'm a little bit confused about how to move this action to every batch. Can you help me with the code you modified? Really thanks a lot. @Nplace-su

Nplace-su commented 6 months ago

> > For anyone who met the same issue, you can move the repeat_interleave and permute op to every batch instead of on the whole data. I can bypass the memory peak via this method, as a workaround.
>
> Hi~ I'm a little bit confused about how to move this action to every batch. Can you help me with the code you modified? Really thanks a lot. @Nplace-su

just remove repeat_interleave and permute in the *, then in the and, change object_rays_info = self.train_dataset.metadata["obj_info"][c, y, x] to object_rays_info = self.train_dataset.metadata["obj_info"][c], and do repeat_interleave and permute ops to object_rays_info in functions like next_eval_image(), etc, every time before you use object_rays_info. Feel free to ask if you meet any problems. @Yurri-hub
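
For readers trying to follow the workaround, a rough sketch under stated assumptions: `per_frame` below is a hypothetical stand-in for `metadata["obj_info"]` with one row per frame, and the sizes are toy values. It only illustrates why indexing by the frame index `c` alone gives the same per-ray values as indexing a dense per-pixel copy with `[c, y, x]`; the actual MARS tensors may be laid out differently.

```python
import torch

n_frames, H, W, C = 3, 4, 5, 6  # toy sizes, assumed for illustration

# One row of object info per frame (hypothetical stand-in for "obj_info").
per_frame = torch.arange(n_frames * C, dtype=torch.float32).reshape(n_frames, C)

# Dense layout: the same row copied to every pixel -> [n_frames, H, W, C].
dense = per_frame[:, None, None, :].repeat(1, H, W, 1)

# A batch of rays, identified by frame index c and pixel (y, x).
c = torch.tensor([0, 2, 2, 1])
y = torch.tensor([0, 3, 1, 2])
x = torch.tensor([4, 0, 2, 1])

rays_dense = dense[c, y, x]  # per-pixel lookup, needs the dense tensor
rays_compact = per_frame[c]  # per-frame lookup, no per-pixel copies

# For a full evaluation image, expand only that one frame on demand.
image_info = per_frame[1:2].repeat_interleave(H * W, dim=0)  # [H*W, C]
```

The dense tensor is built here only to compare against; in the workaround it is never materialized for the whole dataset, so the peak memory scales with one batch or one image rather than with every pixel of every frame.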

Yurri-hub commented 6 months ago

> > > For anyone who met the same issue, you can move the repeat_interleave and permute op to every batch instead of on the whole data. I can bypass the memory peak via this method, as a workaround.
> >
> > Hi~ I'm a little bit confused about how to move this action to every batch. Can you help me with the code you modified? Really thanks a lot. @Nplace-su
>
> just remove repeat_interleave and permute in the *, then in the and, change object_rays_info = self.train_dataset.metadata["obj_info"][c, y, x] to object_rays_info = self.train_dataset.metadata["obj_info"][c], and do repeat_interleave and permute ops to object_rays_info in functions like next_eval_image(), etc, every time before you use object_rays_info. Feel free to ask if you meet any problems. @Yurri-hub

Thanks a lot. I'll try it and report back soon.

Yurri-hub commented 6 months ago

Thanks to your help, I have successfully optimized the memory usage, and both the training and evaluation processes are functioning normally now. However, I haven't used any permute ops for object operations, probably because the order itself is already correct. Just like this: (two screenshots) Once again, thank you very much for your assistance! @Nplace-su

Nplace-su commented 6 months ago

> Thanks to your help, I have successfully optimized the memory usage, and both the training and evaluation processes are functioning normally now. However, I haven't used any permute ops for object operations, probably because the order itself is already correct. Just like this: (two screenshots) Once again, thank you very much for your assistance! @Nplace-su

@Yurri-hub Nice, thanks for your information. Actually, I have no idea what the permute op is for, so I just keep it for safety LOL. (After a more careful look, you didn't use permute ops because you repeat_interleave on dim 0, which is equal to the original repeat_interleave on dim 1 and then permute)
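
The equivalence mentioned in the parenthesis can be checked on a toy tensor. The shapes here are assumptions chosen for illustration (the screenshots are not available, so the real shapes are unknown): `info` holds `K` objects with a singleton axis that gets repeated `P` times.

```python
import torch

K, P, C = 3, 5, 2  # objects, repeats, channels: assumed toy sizes
info = torch.arange(K * C, dtype=torch.float32).reshape(K, 1, C)

# Original order: repeat along dim 1, then move the repeated axis to the front.
a = info.repeat_interleave(P, dim=1).permute(1, 0, 2)  # [P, K, C]

# Alternative: put the singleton axis first, then repeat along dim 0.
b = info.permute(1, 0, 2).repeat_interleave(P, dim=0)  # [P, K, C]
```

Both paths end with the repeated axis in front and identical values, so if the data already has the singleton axis first, the trailing permute is not needed.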