DecaYale / RNNPose

RNNPose: Recurrent 6-DoF Object Pose Refinement with Robust Correspondence Field Estimation and Pose Optimization, CVPR 2022
Apache License 2.0

Batch size > 1 #23

Open mqtjean opened 1 year ago

mqtjean commented 1 year ago

Hi, thank you for your work!

I am using your code and have managed to train on the LINEMOD dataset. However, in the template config file the batch size is set to 1.

When I try a batch size > 1, I get this error after the neighbors calibration step:

/opt/miniconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/container.py:435: UserWarning: Setting attributes on ParameterList is not supported.
  warnings.warn("Setting attributes on ParameterList is not supported.")
correspondences_2d3d
all input arrays must have the same shape
Traceback (most recent call last):
  File "/home/RNNPose//tools/train.py", line 775, in <module>
    fire.Fire()
  File "/opt/miniconda3/envs/py37/lib/python3.7/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/miniconda3/envs/py37/lib/python3.7/site-packages/fire/core.py", line 471, in _Fire
    target=component.__name__)
  File "/opt/miniconda3/envs/py37/lib/python3.7/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/RNNPose//tools/train.py", line 240, in multi_proc_train
    args=( params,) )
  File "/opt/miniconda3/envs/py37/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 199, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/miniconda3/envs/py37/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
    while not context.join():
  File "/opt/miniconda3/envs/py37/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 118, in join
    raise Exception(msg)
Exception: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/opt/miniconda3/envs/py37/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/home/RNNPose/tools/train.py", line 275, in train_worker
    gpu_id=params.start_gpu_id+rank,
  File "/home/RNNPose/tools/train.py", line 670, in train
    raise e
  File "/home/RNNPose/tools/train.py", line 551, in train
    for example in dataloader:
  File "/opt/miniconda3/envs/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/opt/miniconda3/envs/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 475, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/opt/miniconda3/envs/py37/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/RNNPose/data/preprocess.py", line 736, in collate_fn_descriptor_deepim
    ret = merge_batch(list_data)
  File "/home/RNNPose/data/preprocess.py", line 70, in merge_batch
    ret[key] = np.stack(elems, axis=0)
  File "<__array_function__ internals>", line 6, in stack
  File "/opt/miniconda3/envs/py37/lib/python3.7/site-packages/numpy/core/shape_base.py", line 427, in stack
    raise ValueError('all input arrays must have the same shape')
ValueError: all input arrays must have the same shape

It happens when stacking correspondences_2d3d, because correspondences_2d3d has a different length for each image.
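The failure can be reproduced with NumPy alone: `np.stack` requires every input to have an identical shape, so two samples with different numbers of correspondences cannot be stacked into one batch array. Below is a minimal sketch, assuming each sample is a dict of NumPy arrays like the ones `merge_batch` receives. `merge_batch_ragged` is a hypothetical helper (not part of RNNPose) that stacks only the keys whose shapes agree and keeps ragged keys as Python lists. The array sizes are made-up placeholders, not the real layout of `correspondences_2d3d`.

```python
import numpy as np


def merge_batch_ragged(list_data):
    """Hypothetical variant of merge_batch (not RNNPose code).

    Stacks a key only when every sample has the same shape for it;
    otherwise keeps a Python list of per-sample arrays.
    """
    ret = {}
    for key in list_data[0]:
        elems = [np.asarray(d[key]) for d in list_data]
        if all(e.shape == elems[0].shape for e in elems):
            ret[key] = np.stack(elems, axis=0)
        else:
            ret[key] = elems  # ragged: leave as a list of arrays
    return ret


# Two samples whose correspondence sets differ in length (5 vs. 8 rows).
# The column count (5) is an arbitrary placeholder.
sample_a = {"image": np.zeros((3, 4, 4)), "correspondences_2d3d": np.zeros((5, 5))}
sample_b = {"image": np.zeros((3, 4, 4)), "correspondences_2d3d": np.zeros((8, 5))}

# Plain np.stack fails exactly as in the traceback.
try:
    np.stack([sample_a["correspondences_2d3d"], sample_b["correspondences_2d3d"]], axis=0)
except ValueError as err:
    print(err)  # all input arrays must have the same shape

batch = merge_batch_ragged([sample_a, sample_b])
print(batch["image"].shape)                # (2, 3, 4, 4)
print(len(batch["correspondences_2d3d"]))  # 2
```

With this kind of collate, the training loop would have to iterate over `batch["correspondences_2d3d"]` sample by sample, which is one form of the for-loop approach; whether the downstream model code accepts a list there is not checked by this sketch.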

Is this something you have faced as well, and do you recommend using a batch size > 1?

DecaYale commented 1 year ago

We always set batch_size=1 in our paper. If you want to use a batch size > 1, there are two possible solutions:

  1. Use DDP (https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) to distribute the computation of a batch across multiple GPUs while keeping batch_size=1 on each GPU.
  2. Modify our code slightly by adding some for-loops or padding to handle per-sample data of different lengths.

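The padding option could look roughly like the sketch below. It assumes each sample is a dict of NumPy arrays and that the ragged arrays agree in every dimension except the first. `pad_ragged` and `merge_batch_padded` are hypothetical names, not RNNPose functions, and `correspondences_2d3d` is the only key treated as ragged here because it is the one reported in the traceback; other keys may need the same treatment.

```python
import numpy as np


def pad_ragged(arrays, pad_value=0.0):
    """Pad a list of (N_i, ...) arrays to (B, N_max, ...).

    Returns the padded array and a boolean mask of shape (B, N_max)
    that is True for real rows and False for padding.
    """
    n_max = max(a.shape[0] for a in arrays)
    feat_shape = arrays[0].shape[1:]
    padded = np.full((len(arrays), n_max) + feat_shape, pad_value, dtype=arrays[0].dtype)
    mask = np.zeros((len(arrays), n_max), dtype=bool)
    for i, a in enumerate(arrays):
        padded[i, : a.shape[0]] = a
        mask[i, : a.shape[0]] = True
    return padded, mask


def merge_batch_padded(list_data, ragged_keys=("correspondences_2d3d",)):
    """Hypothetical collate (not RNNPose code): pad ragged keys, stack the rest."""
    ret = {}
    for key in list_data[0]:
        elems = [np.asarray(d[key]) for d in list_data]
        if key in ragged_keys:
            ret[key], ret[key + "_mask"] = pad_ragged(elems)
        else:
            ret[key] = np.stack(elems, axis=0)
    return ret


# Placeholder samples with 5 and 8 correspondences respectively.
samples = [
    {"image": np.zeros((3, 4, 4)), "correspondences_2d3d": np.ones((5, 5))},
    {"image": np.zeros((3, 4, 4)), "correspondences_2d3d": np.ones((8, 5))},
]
batch = merge_batch_padded(samples)
print(batch["correspondences_2d3d"].shape)        # (2, 8, 5)
print(batch["correspondences_2d3d_mask"].sum(1))  # [5 8]
```

The mask is the important part: any loss, matching, or pose-optimization step that reads the padded array would have to ignore rows where the mask is False, otherwise the zero-filled rows would be treated as real correspondences. Padding trades some memory for a single stacked array, while the for-loop option avoids padding at the cost of per-sample processing.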
mqtjean commented 1 year ago

Thank you very much for your answer; I will look into it!