zubair-irshad / CenterSnap

Pytorch code for ICRA'22 paper: "Single-Shot Multi-Object 3D Shape Reconstruction and Categorical 6D Pose and Size Estimation"
https://zubair-irshad.github.io/projects/CenterSnap.html

Error while trying to run finetune on real data #26

Open samanfahandezh opened 1 year ago

samanfahandezh commented 1 year ago

I'm trying to run the finetuning code:

./runner.sh net_train.py @configs/net_config_real_resume.txt --checkpoint \path\to\best\checkpoint

But I get this error: Can't pickle <class 'zstd.ZstdError'>

This happens when the code tries to load the checkpoint in the common.py script: torch.load(hparams.checkpoint, map_location='cpu')['state_dict']

zstandard and all the other dependencies are correctly installed on my machine, and the checkpoint path is correct. It seems there is an issue with parsing the checkpoint, but I don't know what it is. Any suggestions or help would be appreciated.
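For reference, here is a minimal standalone snippet that reproduces just this loading step (the checkpoint path is where I copied the file; the variable names are illustrative, not copied from common.py):

import torch

# Load the Lightning checkpoint on CPU and pull out the model weights,
# mirroring the call in common.py described above.
ckpt_path = "configs/epoch=48.ckpt"
state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
print("loaded", len(state_dict), "parameter tensors")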

zubair-irshad commented 1 year ago

Can you please share the full error trace along with the actual command you are running?

Are you able to train fine on synthetic data and only get this error for finetuning?

samanfahandezh commented 1 year ago

This is the actual command that I tried to run:

./runner.sh net_train.py @configs/net_config_real_resume.txt --checkpoint configs/epoch=48.ckpt

I copied epoch=48.ckpt into the configs directory, but I also tried it with the checkpoint in its original location (from the synthetic training run) and it didn't work; same error.

I didn't try finetuning with synthetic data, but that shouldn't be an issue, since I think the problem here is reading and parsing the checkpoint. This is the error I'm getting while running the command:

Samples per epoch 17272

Steps per epoch 539

Target steps: 240000

Actual steps: 240394

Epochs: 446

Using model class from: /home/jovyan/CenterSnap/simnet/lib/net/models/panoptic_net.py

Restoring from checkpoint: configs/epoch=48.ckpt

Validation sanity check: 0it [00:00, ?it/s]Traceback (most recent call last):

File "/opt/conda/lib/python3.10/multiprocessing/queues.py", line 244, in _feed

obj = _ForkingPickler.dumps(obj)

File "/opt/conda/lib/python3.10/multiprocessing/reduction.py", line 51, in dumps

cls(buf, protocol).dump(obj)

_pickle.PicklingError: Can't pickle <class 'zstd.ZstdError'>: import of module 'zstd' failed

(The same traceback is repeated three more times in the output.)

zubair-irshad commented 1 year ago

Thanks, can you share the config file as well, and explain how you are setting the paths in it?

Did you also generate the real data as described in the readme in the format required by our repo?

samanfahandezh commented 1 year ago

I don't think this error has anything to do with the dataset or the config file, since it happens after the config file is parsed, while loading the checkpoint. Here is the config file for the finetune run anyway; I didn't change anything in it. The real data is also in the same format described in the Readme; in fact, I'm using the preprocessed Real data you provided. Do I need to do anything else to make it ready for finetuning?

--max_steps=240000
--finetune_real=True
--model_file=models/panoptic_net.py
--model_name=res_fpn
--output=results/CenterSnap_Finetune/Real
--train_path=file://data/Real/train
--train_batch_size=32
--train_num_workers=5
--val_path=file://data/Real/test
--val_batch_size=32
--val_num_workers=5
--optim_learning_rate=0.0006
--optim_momentum=0.9
--optim_weight_decay=1e-4
--optim_poly_exp=0.9
--optim_warmup_epochs=1
--loss_seg_mult=1.0
--loss_depth_mult=1.0
--loss_vertex_mult=0.1
--loss_rotation_mult=0.1
--loss_heatmap_mult=100.0
--loss_latent_emb_mult=0.1
--loss_abs_pose_mult=0.1
--loss_z_centroid_mult=0.1
--wandb_name=NOCS_Real_Finetune

zubair-irshad commented 1 year ago

Thanks for answering my questions. Did you mean that you downloaded the data from the link we provided and untarred the real data into data/Real/train? Did you extract the data to /data/Real/train (an absolute path) or under the original CenterSnap directory, i.e. CenterSnap/data/Real/train? If your data is under /data/Real/train, then you might be missing an additional forward slash in your file paths, i.e. in that case it should be file:///data/Real/train.
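To illustrate the difference (this just uses Python's standard urllib.parse for demonstration; the exact way our code resolves these URIs may differ slightly):

from urllib.parse import urlparse

# With two slashes, "data" is parsed as the host and the path starts at /Real/train.
print(urlparse("file://data/Real/train"))   # netloc='data', path='/Real/train'
# With three slashes, the host is empty and the full absolute path is preserved.
print(urlparse("file:///data/Real/train"))  # netloc='', path='/data/Real/train'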

If your file paths are correct, can you write two small standalone Python scripts to reproduce your errors? Unfortunately, I don't know of any other cause of this error besides wrong file paths, so if you have double-checked those, I would:

  1. Try to load the checkpoint state dict like we do in our inference notebooks and see if it loads fine.
  2. Try to load a small batch of data, maybe just one .pickle.zstd file, and see if you are able to load and inspect it fine.
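For the second check, something along these lines should work (this is only a sketch: I am assuming the samples are zstd-compressed pickles based on the file extension, the filename below is a placeholder, and our data loader may use a different zstd binding than the zstandard package shown here):

import pickle
import zstandard

# Placeholder path: point this at one real sample under data/Real/train.
sample_path = "data/Real/train/<some_sample>.pickle.zstd"

# Decompress the zstd stream and unpickle the sample it contains.
with open(sample_path, "rb") as f:
    with zstandard.ZstdDecompressor().stream_reader(f) as reader:
        sample = pickle.loads(reader.read())

print(type(sample))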

BTW how did you get the epoch=48 checkpoint? Is it the one that performed the best on the synthetic validation set after training on synthetic data from scratch?

samanfahandezh commented 1 year ago

I ran the training for 50 epochs from scratch using the provided synthetic dataset. The data/Real/train folder is under the original /CenterSnap directory, so I think file://data/Real/train should be correct.

Is it possible that there is not enough memory for parsing the checkpoint when calling: torch.load(hparams.checkpoint, map_location='cpu')['state_dict'] ? If so, what should be the remedy?

The problem is that I'm using Kubeflow to run everything, and there are other considerations/parts of the code that prevent me from figuring out what the issue is here.

zubair-irshad commented 1 year ago

I don't think GPU memory is the issue when loading the checkpoint (since you have already trained on synthetic data), and I haven't seen this error on my end. My guess with the zstd error was that the data is not being loaded properly, but it looks like your paths and everything else data-wise are correct.

Unfortunately, I don't have any other insights beyond my two recommendations above, which I will repeat below:

  1. Try to load the checkpoint state dict like we do in our inference notebooks and see if it loads fine.
  2. Try to load a small batch of data, maybe just one .pickle.zstd file, and see if you are able to load and inspect it fine.