wenbowen123 / iros20-6d-pose-tracking

[IROS 2020] se(3)-TrackNet: Data-driven 6D Pose Tracking by Calibrating Image Residuals in Synthetic Domains

About the assertion in `train.py` #54

martinlyra closed this issue 1 year ago

martinlyra commented 1 year ago

Some context first: my partner Marcus and I are students working on our master's thesis, which is about object-detection applications in ROS on our own robot. We followed the README to apply the tracker to our problem, which involves our own 3D-scanned models whose poses we want to detect and track.

When I run `train.py` after generating about 1000 training and 200 validation samples and creating the data pairs (297 training and 200 validation), training fails on an assertion:

$ python train.py 
output_path /home/marcusmartin/repos/iros20-6d-pose-tracking/train_output/
loaded dataset info from: /home/marcusmartin/repos/iros20-6d-pose-tracking/generated_data_pair/train_data_blender_DR/../dataset_info.yml
self.cam_K:
 [[1.066778e+03 0.000000e+00 3.129869e+02]
 [0.000000e+00 1.067487e+03 2.413109e+02]
 [0.000000e+00 0.000000e+00 1.000000e+00]]
making dataset... for train
#dataset: 297
self.trans_normalizer=0.03, self.rot_normalizer=0.08726646259971647
len(train_dataset)= 297
Computing mean std for n=10000
Traceback (most recent call last):
  File "/home/marcusmartin/repos/iros20-6d-pose-tracking/train.py", line 110, in <module>
    for i, (data, target, A_in_cams, B_in_cams, rgbA, rgbB, maskA, maskB) in enumerate(train_loader):
  File "/home/marcusmartin/miniconda3/envs/iros20/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/home/marcusmartin/miniconda3/envs/iros20/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/home/marcusmartin/miniconda3/envs/iros20/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/home/marcusmartin/miniconda3/envs/iros20/lib/python3.10/site-packages/torch/_utils.py", line 644, in reraise
    raise exception
AssertionError: Caught AssertionError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/marcusmartin/miniconda3/envs/iros20/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/marcusmartin/miniconda3/envs/iros20/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/marcusmartin/miniconda3/envs/iros20/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/marcusmartin/repos/iros20-6d-pose-tracking/datasets.py", line 107, in __getitem__
    data, target, rgbA, rgbB, maskA, maskB = self.processData(rgbA,depthA,A_in_cam,rgbB,depthB,B_in_cam,maskB)
  File "/home/marcusmartin/repos/iros20-6d-pose-tracking/datasets.py", line 154, in processData
    assert (rot_label>=-1).all() and (rot_label<=1).all(),'root:\n{}\nrot_label\n{}\n A2B_in_cam_rot{}\n'.format(self.root,rot_label,A2B_in_cam_rot)
AssertionError: root:
/home/marcusmartin/repos/iros20-6d-pose-tracking/generated_data_pair/train_data_blender_DR
rot_label
[-0.34210497  0.45855399  1.42401367]
 A2B_in_cam_rot[[ 0.99149074 -0.12449399  0.03804475]
 [ 0.12330102  0.99184521  0.03224788]
 [-0.04174899 -0.02728239  0.99875556]]

I would like to understand what this means and what we can do to avoid this error. We have successfully generated and trained on smaller datasets (5/2 and 50/20 training/validation samples).

wenbowen123 commented 1 year ago

At https://github.com/wenbowen123/iros20-6d-pose-tracking/blob/fcf714d09bc9c3f711e2230b32dfb7f3b1884e86/datasets.py#LL149C14-L149C14 we normalize the rotations to the range [-1, 1] using `self.rot_normalizer`, which is set at https://github.com/wenbowen123/iros20-6d-pose-tracking/blob/fcf714d09bc9c3f711e2230b32dfb7f3b1884e86/dataset_info.yml#L13. The normalizer represents the maximum rotation you expect between neighboring frames at test time.
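To see why the assertion fires, here is a minimal NumPy sketch. It assumes `rot_label` is the rotation-vector (axis-angle) form of `A2B_in_cam_rot` divided by `rot_normalizer`; the helpers `rotation_log` and `rotation_label` are hypothetical names, not functions from the repository. Under that assumption the sketch reproduces the numbers in the traceback: the pair's relative rotation is roughly 7.7°, more than the 5° implied by `rot_normalizer=0.0873`, so one label component ends up above 1.

```python
import numpy as np

def rotation_log(R: np.ndarray) -> np.ndarray:
    """Rotation vector (axis * angle) of a 3x3 rotation matrix, for angles in [0, pi)."""
    cos_theta = np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    theta = np.arccos(cos_theta)
    if theta < 1e-8:
        return np.zeros(3)  # (near) identity: no rotation
    # The skew-symmetric part of R encodes the axis scaled by 2*sin(theta)
    v = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    return v * theta / (2.0 * np.sin(theta))

def rotation_label(A2B_in_cam_rot: np.ndarray, rot_normalizer: float) -> np.ndarray:
    """Hypothetical re-creation of rot_label: rotation vector divided by the normalizer."""
    return rotation_log(A2B_in_cam_rot) / rot_normalizer

# Relative rotation copied from the assertion message above
A2B_in_cam_rot = np.array([
    [0.99149074, -0.12449399, 0.03804475],
    [0.12330102, 0.99184521, 0.03224788],
    [-0.04174899, -0.02728239, 0.99875556],
])
rot_normalizer = 5 * np.pi / 180  # 0.0872..., as printed in the log

rot_label = rotation_label(A2B_in_cam_rot, rot_normalizer)
angle_deg = np.degrees(np.linalg.norm(rotation_log(A2B_in_cam_rot)))
print(rot_label)            # approx. [-0.342  0.459  1.424], as in the traceback
print(round(angle_deg, 1))  # approx. 7.7 degrees, above the 5-degree normalizer
# The same range check the assertion performs
in_range = bool((rot_label >= -1).all() and (rot_label <= 1).all())
print(in_range)  # False
```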

In your case, you can either increase the normalizer if you expect very large rotational motion (which also makes training harder and longer), or generate the synthetic data pairs with a smaller rotation difference.
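Both remedies can be sketched in NumPy, again assuming the label is the rotation vector divided by the normalizer. `min_rot_normalizer` computes the smallest normalizer that would let a given set of relative rotations pass the range check (the "increase the normalizer" option), and `sample_bounded_rotation` draws a relative rotation with a capped angle, which is one way to generate pairs with a smaller rotation difference. All helper names are hypothetical and not part of the repository; its actual pair generator may perturb poses differently.

```python
import numpy as np

def rotation_log(R: np.ndarray) -> np.ndarray:
    """Rotation vector (axis * angle) of a 3x3 rotation matrix, for angles in [0, pi)."""
    cos_theta = np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    theta = np.arccos(cos_theta)
    if theta < 1e-8:
        return np.zeros(3)
    v = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    return v * theta / (2.0 * np.sin(theta))

def rotation_exp(rotvec: np.ndarray) -> np.ndarray:
    """Rodrigues' formula: rotation vector -> 3x3 rotation matrix."""
    theta = np.linalg.norm(rotvec)
    if theta < 1e-12:
        return np.eye(3)
    k = rotvec / theta
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)

def min_rot_normalizer(relative_rotations) -> float:
    """Option 1: smallest rot_normalizer for which every label component lies in [-1, 1]."""
    return max(float(np.abs(rotation_log(R)).max()) for R in relative_rotations)

def sample_bounded_rotation(max_rot_rad: float, rng: np.random.Generator) -> np.ndarray:
    """Option 2: random relative rotation whose total angle stays below max_rot_rad."""
    axis = rng.normal(size=3)
    axis /= np.linalg.norm(axis)  # isotropic random direction
    angle = rng.uniform(0.0, max_rot_rad)
    return rotation_exp(axis * angle)

# Generate pairs whose rotation never exceeds the normalizer
rng = np.random.default_rng(0)
rot_normalizer = 5 * np.pi / 180
pairs = [sample_bounded_rotation(rot_normalizer, rng) for _ in range(1000)]
print(min_rot_normalizer(pairs) <= rot_normalizer)  # True: all labels stay in range
```

Capping the total angle is slightly stricter than the per-component check in the assertion: under the same assumption, the pair from the traceback (about 7.7° overall) would need a normalizer of only about 0.124 rad (roughly 7.1°) to pass.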

martinlyra commented 1 year ago

So the normalizer value I'm looking for is in `datasets.py`, in `TrackDataset`'s constructor: `rot_normalizer=5*pi/180`, right?

wenbowen123 commented 1 year ago

Where did you get `rot_normalizer=5*pi/180`? Can you paste the line?

martinlyra commented 1 year ago

At the end of this constructor:

https://github.com/wenbowen123/iros20-6d-pose-tracking/blob/fcf714d09bc9c3f711e2230b32dfb7f3b1884e86/datasets.py#L51

When I checked whether it was being set from `dataset_info.yml`, I couldn't find `max_rotation` being used for `rot_normalizer`.

wenbowen123 commented 1 year ago

Hi, I just pushed a fix. `predict.py` is still hard-coded to reproduce the previous results, but when you train on your own data, `train.py` should now use the parameters from `dataset_info.yml`.
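The general idea of the fix (deriving the normalizer from the dataset config instead of hard-coding it) can be sketched as below. This is not the code that was pushed: it assumes `dataset_info.yml` has already been parsed into a dict (for example with PyYAML's `yaml.safe_load`), that the relevant key is `max_rotation`, and that its value is in degrees. The helper `rot_normalizer_from_info` is hypothetical.

```python
import math

# The value previously hard-coded in TrackDataset's constructor
DEFAULT_ROT_NORMALIZER = 5 * math.pi / 180

def rot_normalizer_from_info(dataset_info: dict) -> float:
    """Hypothetical helper: derive rot_normalizer from a parsed dataset_info.yml.

    Assumes a 'max_rotation' key expressed in degrees; falls back to the old
    hard-coded 5 degrees when the key is missing.
    """
    max_rotation_deg = dataset_info.get("max_rotation")
    if max_rotation_deg is None:
        return DEFAULT_ROT_NORMALIZER
    return math.radians(float(max_rotation_deg))

info = {"max_rotation": 10}  # e.g. what yaml.safe_load might return (assumed layout)
print(rot_normalizer_from_info(info))  # about 0.1745 rad
```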

martinlyra commented 1 year ago

Awesome! Thanks for helping us with this.