You don't need .avi files for finetuning. I recommend doing it this way:
Convert each sequence into a numpy array of shape (num_frames, 17, 3), where the 17 joints follow the H36M_JOINT_TO_LABEL order (see here for the order) and 3 is (x, y, conf_score) for a 2D sequence and (x, y, z) for a 3D sequence. There are a few challenges you might face along the way.
Then store the numpy files in a directory (let's name it keypoints) with the following structure:
.
└── keypoints/
├── train/
│ ├── sequence_1_2D.npy
│ ├── sequence_1_3D.npy
│ ├── sequence_2_2D.npy
│ ├── sequence_2_3D.npy
│ └── ...
└── val/
├── sequence_1_2D.npy
├── sequence_1_3D.npy
├── sequence_2_2D.npy
├── sequence_2_3D.npy
└── ...
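For example, a minimal sketch of how one of those CSV keypoint files could be converted into this layout (pandas usage and the CSV column layout are assumptions here; the 17-joint order must match H36M_JOINT_TO_LABEL):
import numpy as np
import pandas as pd  # assumption: the CSV files can be read with pandas

def csv_to_npy(csv_path, out_path, is_3d=False):
    """Convert one keypoint CSV into a (num_frames, 17, 3) numpy array and save it."""
    values = pd.read_csv(csv_path).to_numpy(dtype=np.float32)
    channels = 3 if is_3d else 2                                 # (x, y, z) or (x, y)
    keypoints = values.reshape(len(values), 17, channels)        # assumes joint-major columns in H36M order
    if not is_3d:
        conf = np.ones((len(values), 17, 1), dtype=np.float32)   # placeholder confidence if the CSV has none
        keypoints = np.concatenate([keypoints, conf], axis=-1)
    np.save(out_path, keypoints)

# e.g. csv_to_npy('Motion2-1.csv', 'keypoints/train/sequence_1_2D.npy')
#      csv_to_npy('CA01_1.csv', 'keypoints/train/sequence_1_3D.npy', is_3d=True)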
Then you need to create a torch dataset for them. In that dataset:
- The __init__ function receives the path to keypoints and data_split (either 'train' or 'val'). Then you list all the numpy files in the train or validation directory.
- In the __getitem__ method you simply receive an index from the torch dataloader and retrieve the corresponding clip from those big lists you defined before. You can also do mirroring augmentation with a probability of 50% (same as what we did here).
Note: You have to normalize both the 2D and 3D sequences to be in the range [-1, 1] for training. This helps the model have stable output. This is the function used for MotionAGFormer normalization:
def normalize(keypoints, w, h, is_3d=False):
result = np.copy(keypoints)
result[..., :2] = keypoints[..., :2] / w * 2 - [1, h / w] # for width and height
if is_3d:
result[..., 2:] = keypoints[..., 2:] / w * 2 # for depth in 3D keypoints
return result
where w and h are the width and height of the video, which you can get from those avi files.
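If you prefer to read them programmatically, a small helper with OpenCV (assuming the .avi files open with cv2) would be:
import cv2

def get_video_resolution(video_path):
    """Return (width, height) of a video file."""
    cap = cv2.VideoCapture(video_path)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()
    return w, h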
Note: In the evaluation code you have to report MPJPE in millimeters. But after normalization the scale is different and in the range [-1, 1], so you have to denormalize it back to millimeters. To do this, you can define the following method in your torch dataset:
def denormalize(self, keypoints, idx, is_3d=False):
h, w = self.data_list_camera[idx]
result = np.copy(keypoints)
result[..., :2] = (keypoints[..., :2] + np.array([1, h / w])) * w / 2
if is_3d:
result[..., 2:] = keypoints[..., 2:] * w / 2
return result
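As a quick sanity check (with made-up values), normalize followed by the inverse mapping above should recover the original coordinates:
import numpy as np

w, h = 1000, 1002                                     # example camera resolution
kpts_3d = np.random.rand(243, 17, 3) * [w, h, 300]    # fake 3D keypoints in the original scale

normed = normalize(kpts_3d, w, h, is_3d=True)
recovered = np.copy(normed)
recovered[..., :2] = (normed[..., :2] + np.array([1, h / w])) * w / 2
recovered[..., 2:] = normed[..., 2:] * w / 2

assert np.allclose(recovered, kpts_3d)                # the round trip is (numerically) exact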
Note that I also assumed that when you clip the videos into sequences of 243 frames, you store a variable called data_list_camera that holds the height and width of each clip (in case different videos are captured with different camera formats).
The good news is that, for one of my course projects, I already did these things, so you can use them with minor modifications. In the __init__ method of your class, assuming that you store the 2D keypoints dictionary in data_2d and the 3D keypoints dictionary in data_3d, and that n_frames=243 and stride=81, we have:
def __init__(self, keypoints_path, data_split, n_frames=243, stride=81, ...):
...
self.data_list_2d, self.data_list_3d, self.data_list_camera = self.split_into_clips(data_2d, data_3d, n_frames, stride)
assert len(self.data_list_2d) == len(self.data_list_3d)
assert len(self.data_list_2d) == len(self.data_list_camera)
where split_into_clips is defined as follows:
def split_into_clips(self, data_2d, data_3d, n_frames, stride):
data_list_2d, data_list_3d, data_list_camera = [], [], []
for sequence_name in data_2d:
keypoints_2d = data_2d[sequence_name]
keypoints_3d = data_3d[sequence_name]['keypoints']
res_h = data_3d[sequence_name]['res_h']
res_w = data_3d[sequence_name]['res_w']
keypoints_2d = self.normalize(keypoints_2d, res_w, res_h)
keypoints_3d = self.normalize(keypoints_3d, res_w, res_h, is_3d=True)
keypoints_2d = keypoints_2d[:keypoints_3d.shape[0]] # Make sure the lengths are equal
clips_2d, clips_3d = self.partition(keypoints_2d, keypoints_3d, n_frames, stride)
data_list_2d.extend(clips_2d)
data_list_3d.extend(clips_3d)
data_list_camera.extend([(res_h, res_w)] * len(clips_2d))
return data_list_2d, data_list_3d, data_list_camera
Note that data_3d is expected to have the following format:
{
'<SEQUENCE_NAME>': {
"keypoints": <NUMPY ARRAY>,
"res_h": <HEIGHT OF THE CLIP>,
"res_w" <WIDTH OF THE CLIP>
}
}
Then the partition method is defined as:
def partition(self, keypoints_2d, keypoints_3d, clip_length, stride):
if self.data_split == "test":
stride = clip_length
clips_2d, clips_3d = [], []
video_length = keypoints_2d.shape[0]
if video_length <= clip_length:
new_indices = self.resample(video_length, clip_length)
clips_2d.append(keypoints_2d[new_indices])
clips_3d.append(keypoints_3d[new_indices])
else:
start_frame = 0
while (video_length - start_frame) >= clip_length:
clips_2d.append(keypoints_2d[start_frame:start_frame + clip_length])
clips_3d.append(keypoints_3d[start_frame:start_frame + clip_length])
start_frame += stride
new_indices = self.resample(video_length - start_frame, clip_length) + start_frame
clips_2d.append(keypoints_2d[new_indices])
clips_3d.append(keypoints_3d[new_indices])
return clips_2d, clips_3d
And the resample method, responsible for resampling a clip to the target length (by repeating or skipping frames), is defined as:
@staticmethod
def resample(original_length, target_length):
"""
Adapted from https://github.com/Walter0807/MotionBERT/blob/main/lib/utils/utils_data.py#L68
Returns an array of frame indices. Elements of the array are in the range (0, original_length - 1) and
there are target_length of them (so it resamples the frames).
"""
even = np.linspace(0, original_length, num=target_length, endpoint=False)
result = np.floor(even)
result = np.clip(result, a_min=0, a_max=original_length - 1).astype(np.uint32)
return result
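For intuition, treating resample as a plain function, the outputs follow directly from the numpy calls above:
print(resample(5, 10))    # [0 0 1 1 2 2 3 3 4 4]  -> a short clip is upsampled by repeating frames
print(resample(10, 5))    # [0 2 4 6 8]            -> a long clip is downsampled by skipping frames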
Final note: You can change the evaluate function used during training to work this way:
def evaluate(args, model, test_loader, device):
print("[INFO] Evaluation")
model.eval()
mpjpe_all, p_mpjpe_all = AverageMeter(), AverageMeter()
with torch.no_grad():
for x, y, indices in tqdm(test_loader):
batch_size = x.shape[0]
x = x.to(device)
if args.flip:
batch_input_flip = flip_data(x)
predicted_3d_pos_1 = model(x)
predicted_3d_pos_flip = model(batch_input_flip)
predicted_3d_pos_2 = flip_data(predicted_3d_pos_flip) # Flip back
predicted_3d_pos = (predicted_3d_pos_1 + predicted_3d_pos_2) / 2
else:
predicted_3d_pos = model(x)
if args.root_rel:
predicted_3d_pos[:, :, 0, :] = 0 # [N,T,17,3]
else:
y[:, 0, 0, 2] = 0
predicted_3d_pos = predicted_3d_pos.detach().cpu().numpy()
y = y.cpu().numpy()
denormalized_predictions = []
for i, prediction in enumerate(predicted_3d_pos):
prediction = test_loader.dataset.denormalize(prediction,
indices[i].item(), is_3d=True)
denormalized_predictions.append(prediction[None, ...])
denormalized_predictions = np.concatenate(denormalized_predictions)
# Root-relative Errors
predicted_3d_pos = denormalized_predictions - denormalized_predictions[..., 0:1, :]
y = y - y[..., 0:1, :]
mpjpe = calculate_mpjpe(predicted_3d_pos, y)
p_mpjpe = calculate_p_mpjpe(predicted_3d_pos, y)
mpjpe_all.update(mpjpe, batch_size)
p_mpjpe_all.update(p_mpjpe, batch_size)
print(f"Protocol #1 error (MPJPE): {mpjpe_all.avg} mm")
print(f"Protocol #2 error (P-MPJPE): {p_mpjpe_all.avg} mm")
return mpjpe_all.avg, p_mpjpe_all.avg
Note that in the function above:
- y, the 3D ground truth returned by __getitem__, is already denormalized.
- The test dataloader also returns the indices (the ones that __getitem__ receives) in addition to the 2D and 3D sequences.
Hope it helps! Cheers.
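To make those two notes concrete, here is a rough sketch of what the __getitem__ could look like in this setup (an illustration, not the exact implementation; it assumes the attributes and the denormalize method defined above, plus import torch):
def __getitem__(self, index):
    keypoints_2d = self.data_list_2d[index]                        # normalized 2D clip, (243, 17, 3)
    keypoints_3d = self.data_list_3d[index]                        # normalized 3D clip, (243, 17, 3)
    if self.data_split == 'train':
        return torch.FloatTensor(keypoints_2d), torch.FloatTensor(keypoints_3d)
    # Validation: ground truth is returned already denormalized, and the index is returned
    # so that evaluate() can denormalize the predictions with the same camera resolution.
    keypoints_3d = self.denormalize(keypoints_3d, index, is_3d=True)
    return torch.FloatTensor(keypoints_2d), torch.FloatTensor(keypoints_3d), index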
Hi. I'm trying to set up a dataset and train it based on this code, but I keep running into a problem with the evaluate function. After loading a pretrained model and calling the evaluate function without any additional training, I noticed that when I print the model output per batch, the output for a certain batch contains NaN values. Also, when I try to calculate MPJPE excluding that batch, the MPJPE values do not seem accurate (they appear excessively small, e.g. < 0.00001). Why might this issue be occurring?
Hi @elisha0904, I'm not sure about the reason behind the NaN values. But for your second question, it's because the output of MotionAGFormer is normalized to be in the range [-1, 1]. The normalization depends on the width and height of the video frame at recording time. So, having those two values, you can denormalize the output as follows:
def denormalize(self, sequence, height, width):
    result = np.copy(sequence)
    result[..., :2] = (result[..., :2] + np.array([1, height / width])) * width / 2   # x, y back to pixels
    result[..., 2:] = result[..., 2:] * width / 2                                     # depth uses the same width scale
    return result
Make sure your input 2D sequences are also normalized before passing them to the model:
def normalize(sequence, width, height, is_3d=False):
result = np.copy(sequence)
result[..., :2] = sequence[..., :2] / width * 2 - [1, height / width]
if is_3d: # This is only required for training to have normalized 3d groundtruth
result[..., 2:] = sequence[..., 2:] / width * 2
return result
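Putting the two together, a rough sketch of running a pretrained MotionAGFormer on one 2D sequence without fine-tuning (here model is an already loaded network and keypoints_2d is a (243, 17, 3) numpy array in pixel coordinates; both are assumptions for illustration):
import numpy as np
import torch

width, height = 900, 900                                    # camera resolution of the recording
x = normalize(keypoints_2d, width, height)                  # now roughly in [-1, 1]
x = torch.from_numpy(x).float().unsqueeze(0)                # add a batch dimension: (1, 243, 17, 3)

model.eval()
with torch.no_grad():
    pred = model(x)[0].cpu().numpy()                        # normalized 3D output, shape (243, 17, 3)

# Denormalize back (same math as the denormalize function above, without the self argument):
pred[..., :2] = (pred[..., :2] + np.array([1, height / width])) * width / 2
pred[..., 2:] = pred[..., 2:] * width / 2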
Another thing to note is that for the Human3.6M dataset, following MotionBERT and LCN, there is a variable called 2.5d_factor that is multiplied in after denormalization. That's because monocular 3D human pose estimation is an ill-posed problem: a tall person far from the camera has the same 2D coordinates as a short person close to the camera. Based on Section 6.2.1 of the LCN paper, for each input sequence a parameter lambda is learned to make the sequences scale invariant, i.e. persons with different heights are scaled to the same range, and MotionAGFormer after denormalization outputs (lambda * 3D coordinates) instead of (3D coordinates). Therefore 2.5d_factor is (1 / lambda), which is multiplied to bring everything back to the same scale.
So in case you want to use MotionAGFormer without fine-tuning, make sure to both normalize and denormalize the data and also compute the parameter lambda as explained in the LCN paper.
Note: This 2.5d_factor is only used for Human3.6M training. For MPI-INF-3DHP, after denormalization you have the 3D coordinates without requiring any further computation.
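If you do use the Human3.6M-style preprocessed data, the correction could look like this (a sketch; factor_2_5d is assumed to be the per-clip 2.5d_factor stored during preprocessing, i.e. 1 / lambda):
def rescale_to_mm(predicted_3d_pos, factor_2_5d):
    """predicted_3d_pos: denormalized output, shape (T, 17, 3); factor_2_5d: per-clip scalar equal to 1 / lambda."""
    pred_mm = predicted_3d_pos * factor_2_5d     # undo the learned scale-invariance factor
    return pred_mm - pred_mm[..., 0:1, :]        # root-relative coordinates, ready for MPJPE in mm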
So if I understand correctly, you're saying that before calculating evaluation metrics like MPJPE, I need to upscale (I'm not sure if 'upscaling' is the right term) the values by multiplying them with a 2.5d_factor?
Let me add some more details to my first question. I'm trying to fine-tune MotionAGFormer using these two datasets:
I've preprocessed the data using the following code and have been trying to train it, taking guidance from previous answers. However, I keep encountering NaN values, and the training doesn't seem to progress correctly. Is the code I've created accurate for this purpose?
import os
import torch
import random
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset
from utils.data import flip_data
from utils.learning import AverageMeter
from loss.pose3d import p_mpjpe as calculate_p_mpjpe
from loss.pose3d import mpjpe as calculate_mpjpe
class FitHscDataset3D(Dataset):
def __init__(self, keypoints_path, data_split, n_frames=243, stride=81, res_h=900, res_w=900, flip=False):
self.data_split = data_split
self.n_frames = n_frames
self.res_h, self.res_w = res_h, res_w
self.flip = flip
self.stride = stride if data_split == 'train' else n_frames
# Load 2D and 3D keypoints data
data_2d, data_3d = self.load_data(keypoints_path, data_split)
# Split data into clips and store them along with camera information
self.data_list_2d, self.data_list_3d, self.data_list_camera = self.split_into_clips(data_2d, data_3d, n_frames, stride)
# Validate the lengths of 2D and 3D data lists
assert len(self.data_list_2d) == len(self.data_list_3d)
assert len(self.data_list_2d) == len(self.data_list_camera)
def load_data(self, keypoints_path, data_split):
data_list_2d, data_list_3d = {}, {}
split_path = os.path.join(keypoints_path, data_split)
if not os.path.exists(split_path):
raise FileNotFoundError(f"Data split path does not exist: {split_path}")
for filename in os.listdir(split_path):
if filename.endswith('_2D.npy'):
sequence_name = filename.replace('_2D.npy', '')
keypoints_2d_file = os.path.join(split_path, filename)
keypoints_3d_file = os.path.join(split_path, sequence_name + '_3D.npy')
if not os.path.isfile(keypoints_2d_file) or not os.path.isfile(keypoints_3d_file):
print(f"Skipping missing file: {sequence_name}")
continue
try:
keypoints_2d = np.load(keypoints_2d_file)
keypoints_3d = np.load(keypoints_3d_file)
except Exception as e:
print(f"Error loading file {filename}: {e}")
continue
if keypoints_2d.ndim != 3 or keypoints_3d.ndim != 3:
print(f"Invalid data dimensions for sequence: {sequence_name}")
continue
data_list_2d[sequence_name] = keypoints_2d
data_list_3d[sequence_name] = {'keypoints': keypoints_3d, 'res_h': self.res_h, 'res_w': self.res_w}
if not data_list_2d or not data_list_3d:
print("Warning: Data lists are empty after loading.")
return data_list_2d, data_list_3d
def split_into_clips(self, data_2d, data_3d, n_frames, stride):
data_list_2d, data_list_3d, data_list_camera = [], [], []
for sequence_name in data_2d:
keypoints_2d = data_2d[sequence_name]
keypoints_3d = data_3d[sequence_name]['keypoints']
res_h = data_3d[sequence_name]['res_h']
res_w = data_3d[sequence_name]['res_w']
if keypoints_2d.shape[0] != keypoints_3d.shape[0]:
print(f"Warning: Mismatch in sequence length for {sequence_name}. Skipping sequence.")
continue
# Normalize keypoints
keypoints_2d = self.normalize(keypoints_2d, res_w, res_h)
keypoints_3d = self.normalize(keypoints_3d, res_w, res_h, is_3d=True)
# Partition into clips
clips_2d, clips_3d = self.partition(keypoints_2d, keypoints_3d, n_frames, stride)
data_list_2d.extend(clips_2d)
data_list_3d.extend(clips_3d)
data_list_camera.extend([(res_h, res_w)] * len(clips_2d))
return data_list_2d, data_list_3d, data_list_camera
def normalize(self, keypoints, w, h, is_3d=False):
result = np.copy(keypoints)
result[..., :2] = keypoints[..., :2] / w * 2 - [1, h / w] # for width and height
if is_3d:
result[..., 2:] = keypoints[..., 2:] / w * 2 # for depth in 3D keypoints
return result
def denormalize(self, keypoints, idx, is_3d=False):
h, w = self.data_list_camera[idx]
result = np.copy(keypoints)
result[..., :2] = (keypoints[..., :2] + np.array([1, h / w])) * w / 2
if is_3d:
result[..., 2:] = keypoints[..., 2:] * w / 2
return result
def partition(self, keypoints_2d, keypoints_3d, clip_length, stride):
if self.data_split == "val":
stride = clip_length
clips_2d, clips_3d = [], []
video_length = keypoints_2d.shape[0]
if video_length <= clip_length:
new_indices = self.resample(video_length, clip_length)
clips_2d.append(keypoints_2d[new_indices])
clips_3d.append(keypoints_3d[new_indices])
else:
start_frame = 0
while (video_length - start_frame) >= clip_length:
clips_2d.append(keypoints_2d[start_frame:start_frame + clip_length])
clips_3d.append(keypoints_3d[start_frame:start_frame + clip_length])
start_frame += stride
new_indices = self.resample(video_length - start_frame, clip_length) + start_frame
clips_2d.append(keypoints_2d[new_indices])
clips_3d.append(keypoints_3d[new_indices])
return clips_2d, clips_3d
def __len__(self):
return len(self.data_list_2d)
def __getitem__(self, index):
keypoints_2d = self.data_list_2d[index]
keypoints_3d = self.data_list_3d[index]
if self.flip and random.random() > 0.5:
    keypoints_2d = flip_data(keypoints_2d)   # mirroring augmentation via the imported flip_data helper
    keypoints_3d = flip_data(keypoints_3d)
keypoints_2d = torch.from_numpy(keypoints_2d).float()
keypoints_3d = torch.from_numpy(keypoints_3d).float()
if self.data_split == 'train': return keypoints_2d, keypoints_3d
else: return keypoints_2d, keypoints_3d, index
@staticmethod
def resample(original_length, target_length):
"""
Adapted from https://github.com/Walter0807/MotionBERT/blob/main/lib/utils/utils_data.py#L68
Returns an array of frame indices. Elements of the array are in the range (0, original_length - 1) and
there are target_length of them (so it resamples the frames).
"""
even = np.linspace(0, original_length, num=target_length, endpoint=False)
result = np.floor(even)
result = np.clip(result, a_min=0, a_max=original_length - 1).astype(np.uint32)
return result
Yes, it either upsamples or downsamples.
I can't find any issue with the provided code. I recommend adding an assert statement in __getitem__ before returning the items, just to check whether there are any NaN values. Then write a function like _test below this class that creates an instance of the dataset and iterates over each sample in a for loop, just to make sure there isn't any NaN value. Finally, you can add an if __name__ == '__main__' block and call this _test function to verify it in this file.
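A minimal sketch of such a check, assuming the FitHscDataset3D class above and a hypothetical data/keypoints path (the asserts can equivalently live inside __getitem__):
import torch

def _test():
    dataset = FitHscDataset3D('data/keypoints', 'train', flip=True)
    for i in range(len(dataset)):
        keypoints_2d, keypoints_3d = dataset[i][:2]          # the val split also returns the index
        assert not torch.isnan(keypoints_2d).any(), f"NaN in 2D clip {i}"
        assert not torch.isnan(keypoints_3d).any(), f"NaN in 3D clip {i}"
    print(f"Checked {len(dataset)} clips, no NaN values found.")

if __name__ == '__main__':
    _test()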
Thanks for your response. I've checked as you suggested, and it seems there's nothing wrong with the data. So what I'm currently thinking is that there might be something lacking in the code for loading and preprocessing the dataset (in other words, it might not be fully compatible with the model).
I have a question about how the data preprocessing code I've created (the FitHscDataset3D class mentioned above) differs from your method of dataset loading and preprocessing.
In your code, you load the dataset using MotionDataset3D, then load the data using torch's DataLoader, and finally perform preprocessing using the DataReaderH36M class. This DataReaderH36M is used not only when initially loading the train/test data but also throughout the evaluation process. I'm curious about the exact role of this code.
What's the difference between our FitHscDataset3D and your MotionDataset3D & DataReaderH36M?
@elisha0904 Since we're using the preprocessed data provided by the MotionBERT paper (click here to see their documentation), which in turn uses the data preprocessed by LCN (apparently this is the preprocessing code), we used the DataReaderH36M in the preprocessing stage to split train/test; in the evaluation code it is used for denormalization and also for figuring out which action corresponds to each truncated sequence (used for MPJPE per action).
I don't think it should be an issue for you.
Just to clarify: for a single batch of data, you're sure that the input values are in the range [-1, 1] and there are no NaN values, yet when you give it to MotionAGFormer, the output becomes NaN!?
Oh, you were right. The values in the batch are not in the range of [-1, 1].
However, as you can see in our FitHscDataset3D class, we used a normalize function to adjust the range of values, but the processed values are not within the [-1, 1] range. There seems to be no issue with the normalize function we used, but I think we need to investigate further to determine the specific cause. Do you have any speculations regarding this issue?
@elisha0904 If the values are not in the range [-1, 1], then it seems that the width and height of the camera used to record the RGB videos are different from what you entered.
According to our dataset homepage, the resolution is stated to be 900x900, and the same resolution is shown when checked with the code below. How should I normalize this?
@Millba For that specific data where the range is not [-1, 1] after normalization, you can verify the range of the skeletons by taking the minimum and maximum over all joints in all frames before normalization. Check whether it is in the range [0, 900].
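For example (a quick check on one raw sequence, using the file layout suggested earlier; the path is hypothetical):
import numpy as np

keypoints_2d = np.load('keypoints/train/sequence_1_2D.npy')    # raw keypoints, before normalization
print("x range:", keypoints_2d[..., 0].min(), keypoints_2d[..., 0].max())   # expect within [0, 900]
print("y range:", keypoints_2d[..., 1].min(), keypoints_2d[..., 1].max())   # expect within [0, 900]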
Hi, @SoroushMehraban. I'm thankful that you've been providing consistent responses. Based on your answers, I've come to realize that the root cause of the issue is not in the code I created.
I've realized that my current issue stems from data preprocessing, and I'm currently investigating how the pkl files for Human3.6M were created in LCN's GitHub repository. It appears that the preprocessed Human3.6M data used in MotionBERT and LCN has been transformed to pixel coordinates for the x and y values. However, my data was not in pixel coordinates.
The problem I'm now facing is that when I follow LCN's preprocessing code, it seems to handle the x and y coordinates correctly, but the values for the z-coordinate look strange. I'm trying to find the root cause of this issue.
Do you happen to know how the preprocessing of the Human3.6M data was carried out, especially regarding factors like 2.5d_factor? (I also attempted to find a solution in MotionBERT's repository, but couldn't find a suitable one.)
@elisha0904 I'm afraid I couldn't find how exactly they converted the videos into sequences; I just used the same preprocessed version as MotionBERT. Regarding the 2.5d_factor, I believe it is the 1/lambda that I explained above.
Also, based on the normalization, I don't think there's any constraint forcing the z-value to be in the range [-1, 1], since it is divided by width/2; so if the depth is more than half of the width, it can be larger.
But I used the same normalization for the demo that I provided in the repository. It first extracts 2D keypoints using HRNet and, using the same normalization that you also apply, passes them to MotionAGFormer.
If you can't solve the issue by Thursday, you can send the batch of data that causes the issue to my email so that I can take a look at it: soroush.mehraban@mail.utoronto.ca
I've solved the problem. The issue was that I overlooked the fact that both the Human3.6M dataset and the dataset I wanted to use for fine-tuning are in world coordinates.
So I referred to the data preprocessing code in LCN's GitHub repository and was able to transform my data into the same format as the preprocessed Human3.6M pkl files used in LCN and MotionBERT, by converting world coordinates to camera coordinates and then to pixel coordinates. Finally, after applying the normalize function, the values are within the [-1, 1] range.
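For anyone hitting the same issue, here is a rough sketch of that conversion chain under a standard pinhole camera model (R, T, f, c are the per-camera extrinsics and intrinsics from the dataset's calibration; a simplified illustration, not the exact LCN code, and it ignores lens distortion):
import numpy as np

def world_to_camera(pts_world, R, T):
    """pts_world: (num_frames, 17, 3) in world coordinates; R: (3, 3) rotation; T: (3,) camera center in world coordinates."""
    return (pts_world - T) @ R.T                      # rotate the translated points into the camera frame

def camera_to_pixel(pts_cam, f, c):
    """f, c: (2,) focal length and principal point in pixels. Keeps depth as the third channel."""
    xy = pts_cam[..., :2] / pts_cam[..., 2:] * f + c  # perspective division followed by the intrinsics
    return np.concatenate([xy, pts_cam[..., 2:]], axis=-1)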
Thank you so much for your detailed responses.
I want to use the data I have for finetuning, but I'm not sure how to preprocess it in the right format. The types of data I have are as follows:
- Video: .avi format videos filmed from 8 different directions for a single movement (8 files per movement).
- Image: A folder of images in .jpg format, cut frame by frame from the video.
- 2D Keypoint: CSV files storing frame-by-frame 2D keypoint values for each video (8 files per movement).
- 3D Keypoint: CSV file storing frame-by-frame 3D keypoint values for a single movement (1 file per movement).
- Annotation: JSON files containing information about each video.
training/
└── raw_data/
    └── video/
        └── CA01_1_camera0/
            ├── Motion2-1.avi
    └── image/
        └── CA01_1_camera0/
            ├── 0.jpg
            ├── 1.jpg
└── labeling_data/
    └── 2d_keypoint/
        ├── Motion2-1.csv
    └── 3d_keypoint/
        ├── CA01_1.csv
        ├── CA01_2.csv
    └── annotation/
        ├── CA01_1_camera0.json
        ├── CA01_1_camera1.json