TaatiTeam / MotionAGFormer

Official implementation of the paper "MotionAGFormer: Enhancing 3D Pose Estimation with a Transformer-GCNFormer Network" (WACV 2024).

Data preprocessing for finetuning using external data. #4

Closed Millba closed 11 months ago

Millba commented 11 months ago

I want to use the data I have for finetuning, but I'm not sure how to preprocess it in the right format. The types of data I have are as follows:

• Video: .avi videos filmed from 8 different directions for a single movement (8 files per movement).
• Image: a folder of .jpg images, cut frame by frame from the video.
• 2D Keypoint: CSV files storing frame-by-frame 2D keypoint values for each video (8 files per movement).
• 3D Keypoint: a CSV file storing frame-by-frame 3D keypoint values for a single movement (1 file per movement).
• Annotation: JSON files containing information about each video.

training/
├── raw_data/
│   ├── video/
│   │   └── CA01_1_camera0/
│   │       └── Motion2-1.avi
│   └── image/
│       └── CA01_1_camera0/
│           ├── 0.jpg
│           └── 1.jpg
└── labeling_data/
    ├── 2d_keypoint/
    │   └── Motion2-1.csv
    ├── 3d_keypoint/
    │   ├── CA01_1.csv
    │   └── CA01_2.csv
    └── annotation/
        ├── CA01_1_camera0.json
        └── CA01_1_camera1.json

SoroushMehraban commented 11 months ago

You don't need .avi files for finetuning. I recommend doing it this way:

  1. Write a script that converts the CSV files into numpy format (see the conversion sketch after this list). The numpy shape for each 2D/3D sequence should be (T, 17, 3), where T is the number of frames in the whole sequence, 17 is the number of joints in Human3.6M format (see H36M_JOINT_TO_LABEL here for the order), and 3 is (x, y, conf_score) for a 2D sequence and (x, y, z) for a 3D sequence. Challenges you might face:
    • You might not have confidence scores for the 2D keypoints -> In that case set them to 1.
    • You might not have the exact joints. For example, you might have the left hip and right hip but not the center hip -> Try to estimate the missing joint by taking their mean (or another estimate that makes more sense to you).
  2. After finishing the first step, you are expected to have a directory (let's call it keypoints) with the following structure:

        .
        └── keypoints/
            ├── train/
            │   ├── sequence_1_2D.npy
            │   ├── sequence_1_3D.npy
            │   ├── sequence_2_2D.npy
            │   ├── sequence_2_3D.npy
            │   └── ...
            └── val/
                ├── sequence_1_2D.npy
                ├── sequence_1_3D.npy
                ├── sequence_2_2D.npy
                ├── sequence_2_3D.npy
                └── ...

  3. Then you need to create a torch dataset for them. In the dataset's __init__ you can do the following:
    • __init__ receives the path to keypoints and data_split (either 'train' or 'val'). Then you list all the numpy files in the train or validation directory.
    • Next, you read them all and store them in a dictionary where the keys are the sequence names and the values are the actual numpy ndarrays.
    • Now the issue is that some sequences are longer than 243 frames and some are shorter, so you need to split them into clips:
      • Sequences shorter than 243 frames can be extrapolated to 243 frames. This is done simply by repeating adjacent frames: instead of 1, 2, 3, 4, ..., 80, the frame indices become 1, 1, 1, 2, 2, 2, 3, 3, 3, ..., 80, 80, 80 so that the total number of frames is 243.
      • Sequences longer than 243 frames can be divided into clips of 243 frames each.
      • A trick used in our project is the concept of a stride, set to 81. So instead of clip one covering frames 0 to 242 and clip two covering frames 243 to 485, clip one covers 0 to 242 and clip two covers 81 to 323.
      • Note that this inflates the training data beyond what it truly is, which doesn't make sense for validation data. So for validation the stride should be 243 (no overlap between clips).
    • By turning the sequences into clips, the dictionary values are now lists of clips rather than whole sequences.
    • Finally, you can flatten this dictionary into one large list. You end up with a large list of 2D clips and a list of exactly the same size for 3D clips.
    • In the __getitem__ method of your dataset you simply receive an index from the torch dataloader and retrieve the clip from those big lists. You can also apply mirroring augmentation with a probability of 50% (same as what we did here).
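
A minimal sketch of step 1 referenced above, assuming each CSV row is one frame with the 17 joints' coordinates flattened in Human3.6M order; the exact column layout of your CSVs is an assumption here, so adjust the reshape and joint indices to your files:

import numpy as np
import pandas as pd

def csv_to_numpy(csv_path, is_3d=False):
    # Assumed layout: one row per frame, columns x_0, y_0, x_1, y_1, ... (2D)
    # or x_0, y_0, z_0, ... (3D) for the 17 joints in Human3.6M order.
    values = pd.read_csv(csv_path).to_numpy(dtype=np.float32)
    n_coords = 3 if is_3d else 2
    keypoints = values.reshape(values.shape[0], 17, n_coords)

    if not is_3d:
        # No confidence scores available -> set them all to 1 (first challenge above).
        conf = np.ones((keypoints.shape[0], 17, 1), dtype=np.float32)
        keypoints = np.concatenate([keypoints, conf], axis=-1)  # (T, 17, 3)

    # If a joint such as the center hip is missing, estimate it from its neighbours,
    # e.g. keypoints[:, HIP] = (keypoints[:, LHIP] + keypoints[:, RHIP]) / 2
    return keypoints

np.save('keypoints/train/sequence_1_2D.npy', csv_to_numpy('Motion2-1.csv'))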

Note: You have to normalize both the 2D and 3D sequences so that the training values are in the range [-1, 1]. This helps the model produce stable output. This is the normalization function used for MotionAGFormer:

def normalize(keypoints, w, h, is_3d=False):
    result = np.copy(keypoints)
    result[..., :2] = keypoints[..., :2] / w * 2 - [1, h / w]   # for width and height
    if is_3d:
        result[..., 2:] = keypoints[..., 2:] / w * 2   # for depth in 3D keypoints
    return result

where w and h are the width and height of the video, which you can get from those .avi files.
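
For example (assuming OpenCV is available; any video reader that reports the resolution works just as well, and keypoints_2d/keypoints_3d stand for the (T, 17, 3) arrays from step 1):

import cv2

cap = cv2.VideoCapture("Motion2-1.avi")
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()

keypoints_2d = normalize(keypoints_2d, w, h)              # (T, 17, 3)
keypoints_3d = normalize(keypoints_3d, w, h, is_3d=True)  # (T, 17, 3)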

Note: In the evaluation code you have to report MPJPE in millimeters. But after normalization the scale is different (in the range [-1, 1]), so you have to denormalize the output back to millimeters. To do so, you can define the following method in your torch dataset:

def denormalize(self, keypoints, idx, is_3d=False):
    h, w = self.data_list_camera[idx]
    result = np.copy(keypoints)
    result[..., :2] = (keypoints[..., :2] + np.array([1, h / w])) * w / 2
    if is_3d:
        result[..., 2:] = keypoints[..., 2:] * w / 2
    return result

Note that I also assumed that when you split the videos into clips of 243 frames, you also store a variable called data_list_camera holding the width and height of the .avi video each clip comes from (in case different videos are captured with different camera formats).

The good news is that, for one of my course projects, I already did these things, so you can use them with minor modifications. In the __init__ method of your class, assuming you store the 2D keypoints dictionary in data_2d and the 3D keypoints dictionary in data_3d, and with n_frames=243 and stride=81, we have:

def __init__(self, keypoints_path, data_split, n_frames=243, stride=81, ...):
   ...
   self.data_list_2d, self.data_list_3d, self.data_list_camera = self.split_into_clips(data_2d, data_3d, n_frames, stride)
   assert len(self.data_list_2d) == len(self.data_list_3d)
   assert len(self.data_list_2d) == len(self.data_list_camera)

where split_into_clips is defined as follows:

    def split_into_clips(self, data_2d, data_3d, n_frames, stride):
        data_list_2d, data_list_3d, data_list_camera = [], [], []
        for sequence_name in data_2d:
            keypoints_2d = data_2d[sequence_name]
            keypoints_3d = data_3d[sequence_name]['keypoints']

            res_h = data_3d[sequence_name]['res_h']
            res_w = data_3d[sequence_name]['res_w']

            keypoints_2d = self.normalize(keypoints_2d, res_w, res_h)
            keypoints_3d = self.normalize(keypoints_3d, res_w, res_h, is_3d=True)

            keypoints_2d = keypoints_2d[:keypoints_3d.shape[0]]  # Make sure the lengths are equal

            clips_2d, clips_3d = self.partition(keypoints_2d, keypoints_3d, n_frames, stride)

            data_list_2d.extend(clips_2d)
            data_list_3d.extend(clips_3d)
            data_list_camera.extend([(res_h, res_w)] * len(clips_2d))

        return data_list_2d, data_list_3d, data_list_camera

Note that data_3d is expected to have the following format:

{
    '<SEQUENCE_NAME>': {
        "keypoints": <NUMPY ARRAY>,
        "res_h": <HEIGHT OF THE CLIP>,
        "res_w": <WIDTH OF THE CLIP>
    }
}

Then the partition method is defined as follows:

    def partition(self, keypoints_2d, keypoints_3d, clip_length, stride):
        if self.data_split == "test":
            stride = clip_length

        clips_2d, clips_3d = [], []
        video_length = keypoints_2d.shape[0]
        if video_length <= clip_length:
            new_indices = self.resample(video_length, clip_length)
            clips_2d.append(keypoints_2d[new_indices])
            clips_3d.append(keypoints_3d[new_indices])
        else:
            start_frame = 0
            while (video_length - start_frame) >= clip_length:
                clips_2d.append(keypoints_2d[start_frame:start_frame + clip_length])
                clips_3d.append(keypoints_3d[start_frame:start_frame + clip_length])
                start_frame += stride
            new_indices = self.resample(video_length - start_frame, clip_length) + start_frame
            clips_2d.append(keypoints_2d[new_indices])
            clips_3d.append(keypoints_3d[new_indices])
        return clips_2d, clips_3d

And the resample method, responsible for the extrapolation, is defined as:

    @staticmethod
    def resample(original_length, target_length):
        """
        Adapted from https://github.com/Walter0807/MotionBERT/blob/main/lib/utils/utils_data.py#L68

        Returns an array of frame indices. Elements are in the range [0, original_length - 1] and
        there are target_length of them (so frames are interpolated or repeated as needed).
        """
        even = np.linspace(0, original_length, num=target_length, endpoint=False)
        result = np.floor(even)
        result = np.clip(result, a_min=0, a_max=original_length - 1).astype(np.uint32)
        return result
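
For instance, resampling an 80-frame sequence to 243 frames repeats each index roughly three times, which is exactly the "1, 1, 1, 2, 2, 2, ..." behaviour described above:

import numpy as np

even = np.linspace(0, 80, num=243, endpoint=False)
indices = np.clip(np.floor(even), 0, 79).astype(np.uint32)
print(indices[:9])    # [0 0 0 0 1 1 1 2 2]
print(indices[-3:])   # [79 79 79]
print(indices.shape)  # (243,)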

Final note: You can change the evaluate function throughout training to work this way:

def evaluate(args, model, test_loader, device):
    print("[INFO] Evaluation")
    model.eval()
    mpjpe_all, p_mpjpe_all = AverageMeter(), AverageMeter()
    with torch.no_grad():
        for x, y, indices in tqdm(test_loader):
            batch_size = x.shape[0]
            x = x.to(device)

            if args.flip:
                batch_input_flip = flip_data(x)
                predicted_3d_pos_1 = model(x)
                predicted_3d_pos_flip = model(batch_input_flip)
                predicted_3d_pos_2 = flip_data(predicted_3d_pos_flip)  # Flip back
                predicted_3d_pos = (predicted_3d_pos_1 + predicted_3d_pos_2) / 2
            else:
                predicted_3d_pos = model(x)
            if args.root_rel:
                predicted_3d_pos[:, :, 0, :] = 0  # [N,T,17,3]
            else:
                y[:, 0, 0, 2] = 0

            predicted_3d_pos = predicted_3d_pos.detach().cpu().numpy()
            y = y.cpu().numpy()

            denormalized_predictions = []
            for i, prediction in enumerate(predicted_3d_pos):
                prediction = test_loader.dataset.denormalize(prediction,
                                                                indices[i].item(), is_3d=True)
                denormalized_predictions.append(prediction[None, ...])
            denormalized_predictions = np.concatenate(denormalized_predictions)

            # Root-relative Errors
            predicted_3d_pos = denormalized_predictions - denormalized_predictions[...,  0:1, :]
            y = y - y[..., 0:1, :]

            mpjpe = calculate_mpjpe(predicted_3d_pos, y)
            p_mpjpe = calculate_p_mpjpe(predicted_3d_pos, y)
            mpjpe_all.update(mpjpe, batch_size)
            p_mpjpe_all.update(p_mpjpe, batch_size)

    print(f"Protocol #1 error (MPJPE): {mpjpe_all.avg} mm")
    print(f"Protocol #2 error (P-MPJPE): {p_mpjpe_all.avg} mm")
    return mpjpe_all.avg, p_mpjpe_all.avg
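
For reference, calculate_mpjpe above comes from the repository's loss/pose3d.py. Conceptually it is the mean Euclidean distance per joint; the following is only a simplified numpy sketch for intuition, not the repository implementation:

import numpy as np

def mpjpe_sketch(predicted, target):
    """Mean per-joint position error over all joints and frames, in the same
    unit as the inputs (millimeters after denormalization)."""
    assert predicted.shape == target.shape  # (N, T, 17, 3)
    return np.mean(np.linalg.norm(predicted - target, axis=-1))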


Hope it helps! Cheers.

elisha0904 commented 11 months ago

Hi. I'm trying to set up a dataset and train it based on this code, but I keep running into a problem with the evaluate function. After loading a pretrained model and calling evaluate without any additional training, I noticed that when I print the model output batch by batch, the output for a certain batch is all NaN values. Also, when I calculate MPJPE excluding that batch, the MPJPE values do not seem accurate (they are excessively small, below 0.00001). Why might this issue be occurring?

SoroushMehraban commented 11 months ago

Hi @elisha0904, I'm not sure about the reason behind the NaN values. But for your second question, it's because the output of MotionAGFormer is normalized to be in the range [-1, 1]. The normalization depends on the width and height of the video frame when the video was recorded. So having those two values, you can denormalize the output as follows:

def denormalize(self, sequence, height, width):
    result = np.copy(sequence)
    result[..., :2] = (result[..., :2] + np.array([1, height / width])) * width / 2
    result[..., 2:] = result[..., 2:] * width / 2
    return result

Also make sure your input 2D data is normalized before passing it to the model:

def normalize(sequence, width, height, is_3d=False):
    result = np.copy(sequence)
    result[..., :2] = sequence[..., :2] / width * 2 - [1, height / width] 
    if is_3d:   # This is only required for training to have normalized 3d groundtruth
        result[..., 2:] = sequence[..., 2:] / width * 2 
    return result

Another thing to note: for the Human3.6M dataset, following MotionBERT and LCN, there's a variable called 2.5d_factor that is multiplied in after denormalization. That's because monocular 3D human pose estimation is an ill-posed problem: a tall person far from the camera has the same 2D coordinates as a short person close to the camera. Based on Section 6.2.1 of the LCN paper, for each input sequence a parameter lambda is learned to make the poses scale invariant, i.e. persons with different heights are scaled to the same range, so after denormalization MotionAGFormer outputs (lambda * 3D coordinates) instead of (3D coordinates). Therefore 2.5d_factor is (1 / lambda), and multiplying by it brings everything back to the same scale.

So in case you want to use MotionAGFormer without fine-tuning, make sure to both normalize and denormalize the data, and also compute the parameter lambda as explained in the LCN paper.

Note: This 2.5d_factor is only used for Human3.6M training. For MPI-INF-3DHP, after denormalization you have the 3D coordinates without requiring any further computation.
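
In code, applying the factor would look roughly like this (factor_2_5d and its per-frame shape are assumptions about how the preprocessed Human3.6M data stores it; this is a sketch, not taken verbatim from the repository):

import numpy as np

def apply_2_5d_factor(denormalized_pred, factor_2_5d):
    """denormalized_pred: (T, 17, 3) output after denormalization, equal to
    lambda * true 3D coordinates.
    factor_2_5d: per-frame scale factor (= 1 / lambda), shape (T,).
    Returns 3D coordinates in millimeters."""
    return denormalized_pred * factor_2_5d[:, None, None]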

elisha0904 commented 11 months ago

So if I understand correctly, you're saying that before calculating evaluation metrics like MPJPE, I need to upscale (I'm not sure if 'upscaling' is the right term) the values by multiplying them with a 2.5d_factor?

Let me add some more details to my first question. I'm trying to fine-tune motionAGFormer using these two datasets:

  1. HumanSC3D
  2. FIT3D

I've preprocessed the data using the following code and have been trying to train it, taking guidance from previous answers. However, I keep encountering NaN values, and the training doesn't seem to progress correctly. Is the code I've created accurate for this purpose?

import os
import torch
import random
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset

from utils.data import flip_data
from utils.learning import AverageMeter
from loss.pose3d import p_mpjpe as calculate_p_mpjpe
from loss.pose3d import mpjpe as calculate_mpjpe

class FitHscDataset3D(Dataset):
    def __init__(self, keypoints_path, data_split, n_frames=243, stride=81, res_h=900, res_w=900, flip=False):
        self.data_split = data_split
        self.n_frames = n_frames
        self.res_h, self.res_w = res_h, res_w
        self.flip= flip
        self.stride = stride if data_split == 'train' else n_frames

        # Load 2D and 3D keypoints data
        data_2d, data_3d = self.load_data(keypoints_path, data_split)

        # Split data into clips and store them along with camera information
        self.data_list_2d, self.data_list_3d, self.data_list_camera = self.split_into_clips(data_2d, data_3d, n_frames,  stride)

        # Validate the lengths of 2D and 3D data lists
        assert len(self.data_list_2d) == len(self.data_list_3d)
        assert len(self.data_list_2d) == len(self.data_list_camera)

    def load_data(self, keypoints_path, data_split):
        data_list_2d, data_list_3d = {}, {}
        split_path = os.path.join(keypoints_path, data_split)

        if not os.path.exists(split_path):
            raise FileNotFoundError(f"Data split path does not exist: {split_path}")

        for filename in os.listdir(split_path):
            if filename.endswith('_2D.npy'):
                sequence_name = filename.replace('_2D.npy', '')
                keypoints_2d_file = os.path.join(split_path, filename)
                keypoints_3d_file = os.path.join(split_path, sequence_name + '_3D.npy')

                if not os.path.isfile(keypoints_2d_file) or not os.path.isfile(keypoints_3d_file):
                    print(f"Skipping missing file: {sequence_name}")
                    continue
                try:
                    keypoints_2d = np.load(keypoints_2d_file)
                    keypoints_3d = np.load(keypoints_3d_file)
                except Exception as e:
                    print(f"Error loading file {filename}: {e}")
                    continue

                if keypoints_2d.ndim != 3 or keypoints_3d.ndim != 3:
                    print(f"Invalid data dimensions for sequence: {sequence_name}")
                    continue

                data_list_2d[sequence_name] = keypoints_2d
                data_list_3d[sequence_name] = {'keypoints': keypoints_3d, 'res_h': self.res_h, 'res_w': self.res_w}

        if not data_list_2d or not data_list_3d:
            print("Warning: Data lists are empty after loading.")

        return data_list_2d, data_list_3d

    def split_into_clips(self, data_2d, data_3d, n_frames, stride):
        data_list_2d, data_list_3d, data_list_camera = [], [], []

        for sequence_name in data_2d:
            keypoints_2d = data_2d[sequence_name]
            keypoints_3d = data_3d[sequence_name]['keypoints']
            res_h = data_3d[sequence_name]['res_h']
            res_w = data_3d[sequence_name]['res_w']

            if keypoints_2d.shape[0] != keypoints_3d.shape[0]:
                print(f"Warning: Mismatch in sequence length for {sequence_name}. Skipping sequence.")
                continue

            # Normalize keypoints
            keypoints_2d = self.normalize(keypoints_2d, res_w, res_h)
            keypoints_3d = self.normalize(keypoints_3d, res_w, res_h, is_3d=True)

            # Partition into clips
            clips_2d, clips_3d = self.partition(keypoints_2d, keypoints_3d, n_frames, stride)

            data_list_2d.extend(clips_2d)
            data_list_3d.extend(clips_3d)
            data_list_camera.extend([(res_h, res_w)] * len(clips_2d))

        return data_list_2d, data_list_3d, data_list_camera

    def normalize(self, keypoints, w, h, is_3d=False):
        result = np.copy(keypoints)
        result[..., :2] = keypoints[..., :2] / w * 2 - [1, h / w]   # for width and height
        if is_3d:
            result[..., 2:] = keypoints[..., 2:] / w * 2   # for depth in 3D keypoints
        return result

    def denormalize(self, keypoints, idx, is_3d=False):
        h, w = self.data_list_camera[idx]
        result = np.copy(keypoints)
        result[..., :2] = (keypoints[..., :2] + np.array([1, h / w])) * w / 2
        if is_3d:
            result[..., 2:] = keypoints[..., 2:] * w / 2
        return result

    def partition(self, keypoints_2d, keypoints_3d, clip_length, stride):
        if self.data_split == "val":
            stride = clip_length

        clips_2d, clips_3d = [], []
        video_length = keypoints_2d.shape[0]
        if video_length <= clip_length:
            new_indices = self.resample(video_length, clip_length)
            clips_2d.append(keypoints_2d[new_indices])
            clips_3d.append(keypoints_3d[new_indices])
        else:
            start_frame = 0
            while (video_length - start_frame) >= clip_length:
                clips_2d.append(keypoints_2d[start_frame:start_frame + clip_length])
                clips_3d.append(keypoints_3d[start_frame:start_frame + clip_length])
                start_frame += stride
            new_indices = self.resample(video_length - start_frame, clip_length) + start_frame
            clips_2d.append(keypoints_2d[new_indices])
            clips_3d.append(keypoints_3d[new_indices])
        return clips_2d, clips_3d

    def __len__(self):
        return len(self.data_list_2d)

    def __getitem__(self, index):
        keypoints_2d = self.data_list_2d[index]
        keypoints_3d = self.data_list_3d[index]

        if self.flip and random.random() > 0.5:
            # flip_data is imported from utils.data at the top of this file
            keypoints_2d = flip_data(keypoints_2d)
            keypoints_3d = flip_data(keypoints_3d)

        keypoints_2d = torch.from_numpy(keypoints_2d).float()
        keypoints_3d = torch.from_numpy(keypoints_3d).float()

        if self.data_split == 'train':
            return keypoints_2d, keypoints_3d
        else:
            return keypoints_2d, keypoints_3d, index

    @staticmethod
    def resample(original_length, target_length):
        """
        Adapted from https://github.com/Walter0807/MotionBERT/blob/main/lib/utils/utils_data.py#L68

        Returns an array of frame indices. Elements are in the range [0, original_length - 1] and
        there are target_length of them (so frames are interpolated or repeated as needed).
        """
        even = np.linspace(0, original_length, num=target_length, endpoint=False)
        result = np.floor(even)
        result = np.clip(result, a_min=0, a_max=original_length - 1).astype(np.uint32)
        return result

SoroushMehraban commented 11 months ago

Yes, it either upsamples or downsamples. I can't find any issue with the provided code. I recommend adding an assert statement in __getitem__, just before returning the items, to check whether there are any NaN values. Then write a function like _test below this class that creates an instance of the dataset and, in a for loop, simply iterates over each sample to make sure there are no NaN values. Finally, under if __name__ == '__main__', call this _test function to verify it in this file (a sketch is below).
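
A sketch of that suggestion, to be placed in the same file as FitHscDataset3D (which already imports torch); the keypoints path is a placeholder:

# inside __getitem__, right before returning:
#     assert not torch.isnan(keypoints_2d).any(), f"NaN in 2D clip {index}"
#     assert not torch.isnan(keypoints_3d).any(), f"NaN in 3D clip {index}"

def _test():
    dataset = FitHscDataset3D("path/to/keypoints", data_split="train")
    for i in range(len(dataset)):
        keypoints_2d, keypoints_3d = dataset[i][:2]
        assert not torch.isnan(keypoints_2d).any(), f"NaN in 2D clip {i}"
        assert not torch.isnan(keypoints_3d).any(), f"NaN in 3D clip {i}"
    print("No NaN values found.")

if __name__ == '__main__':
    _test()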

elisha0904 commented 11 months ago

Thanks for your response. I've checked as you suggested, and it seems there's nothing wrong with the data. So what I'm currently thinking is that there might be something lacking in the code that builds and preprocesses the dataset (in other words, it might not be fully compatible with the model).

I have a question about how the data preprocessing code I've created (the FitHscDataset3D class above) differs from your method of dataset creation and preprocessing.

In your code, you build the dataset using MotionDataset3D, then load the data using torch's DataLoader, and finally perform preprocessing using the DataReaderH36M class. This DataReaderH36M is used not only when initially loading the train/test data but also throughout the evaluation process. I'm curious about the exact role of this code.

What's the difference between our FitHscDataset3D and your MotionDataset3D & DataReaderH36m?

SoroushMehraban commented 11 months ago

@elisha0904 Since we're using the preprocessed data provided by the MotionBERT paper (click here to see their documentation), and they in turn use the data preprocessed by LCN (apparently this is the preprocessing code), we use DataReaderH36M in the preprocessing stage to split train/test; in the evaluation code it is used for denormalization and for figuring out which action each truncated sequence corresponds to (used for per-action MPJPE). I don't think it should be an issue for you.

Just to clarify: for a single batch of data, you're sure that the input values are in the range [-1, 1] and there aren't any NaN values, yet when you give it to MotionAGFormer, the output becomes NaN?

elisha0904 commented 11 months ago

Oh, you were right. The values in the batch are not in the range of [-1, 1].

However, as you can see in our FitHscDataset3D class, we used a normalize function to adjust the range of values, but the processed values are not within the [-1, 1] range. There seems to be no issue with the normalize function we used, but I think we need to investigate further to determine the specific cause. Do you have any speculations regarding this issue?

SoroushMehraban commented 11 months ago

@elisha0904 If the values are not in the range [-1, 1], then it seems that the width and height of the camera used to record the RGB videos are different from what you entered.

Millba commented 11 months ago

According to our dataset homepage, the resolution is stated to be 900x900, and the same resolution is shown when I check it with the code below. How should I normalize this?

SoroushMehraban commented 11 months ago

@Millba For that specific data, when the values are not in the range [-1, 1] after normalization, you can verify the range of the skeletons by taking the minimum and maximum over all joints in all frames before normalization (see the quick check below). Verify whether they are in the range [0, 900].
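
Something like this is enough for the check (keypoints_2d here stands for one raw, un-normalized (T, 17, 3) sequence):

import numpy as np

def check_pixel_range(keypoints_2d, res_w=900, res_h=900):
    # The third column is the confidence score, so only x and y are checked.
    x, y = keypoints_2d[..., 0], keypoints_2d[..., 1]
    print(f"x in [{x.min():.1f}, {x.max():.1f}], expected within [0, {res_w}]")
    print(f"y in [{y.min():.1f}, {y.max():.1f}], expected within [0, {res_h}]")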

elisha0904 commented 11 months ago

Hi, @SoroushMehraban. Thank you for the consistent responses. Based on your answers, I've come to realize that the root cause of the issue is not in the code I created.

I've realized that my current issue stems from data preprocessing, and I'm currently investigating how the pkl files for Human3.6M were created in LCN's GitHub repository. It appears that the preprocessed Human3.6M data used by MotionBERT and LCN has had its x and y values transformed to pixel coordinates. However, my data was not in pixel coordinates.

The problem I'm now facing is that when I follow LCN's preprocessing code, it seems to handle the x and y coordinates correctly, but the values for the z coordinate look strange. I'm trying to find the root cause of this issue.

Do you happen to know how the preprocessing of the Human3.6M data was carried out, especially regarding factors like 2.5d_factor? (I also tried to find an answer in MotionBERT's repository, but couldn't find a suitable one.)

SoroushMehraban commented 11 months ago

@elisha0904 I'm afraid I couldn't find exactly how they converted the videos into sequences; I just used the same preprocessed version as MotionBERT. Regarding the 2.5d_factor, I believe it is the 1/lambda that I explained above.

Also, based on the normalization, I don't think there's any constraint forcing the z value to be in the range [-1, 1]: since it is divided by width/2, the normalized depth can exceed 1 whenever the depth is more than half of the width.

But I used the same normalization for the demo provided in the repository. It first extracts 2D keypoints using HRNet and, using the same normalization that you apply, passes them to MotionAGFormer.

If you can't solve the issue by Thursday, you can send the batch of data that causes it to my email so that I can take a look: soroush.mehraban@mail.utoronto.ca

elisha0904 commented 11 months ago

I've solved the problem. The issue was that I overlooked the fact that both the Human3.6M dataset and the dataset I wanted to use for fine-tuning are in world coordinates.

So I referred to the data preprocessing code in the LCN GitHub repository and was able to transform my data into the same format as the preprocessed Human3.6M pkl files used by LCN and MotionBERT, by converting world coordinates to camera coordinates and then to pixel coordinates. Finally, after applying the normalize function, the values are within the [-1, 1] range.
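
For anyone hitting the same issue, the conversion follows the standard pinhole camera model, sketched below. R, T, f, c are the per-camera extrinsics/intrinsics from your dataset; this is a generic sketch, not a copy of LCN's preprocessing script, and the exact depth/2.5D scaling should still follow their code:

import numpy as np

def world_to_camera(points_world, R, T):
    """points_world: (num_frames, 17, 3) joints in world coordinates.
    R: (3, 3) rotation matrix, T: (3,) camera center in world coordinates."""
    return (points_world - T) @ R.T

def camera_to_pixel(points_cam, f, c):
    """f: (2,) focal lengths and c: (2,) principal point, both in pixels.
    Projects x/y perspectively; the depth (z) is carried along unchanged."""
    pixel_xy = points_cam[..., :2] / points_cam[..., 2:3] * f + c
    return np.concatenate([pixel_xy, points_cam[..., 2:3]], axis=-1)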

Thank you so much for your detailed responses.