traveller59 / second.pytorch

SECOND for KITTI/NuScenes object detection
MIT License

All NaN values: runtime.step=1900, loss.cls_loss=nan, loss.cls_loss_rt=nan, loss.loc_loss=nan, loss.loc_loss_rt=nan, loss.loc_elem=[nan, nan, nan, nan, nan, nan, nan] #144

Closed: chowkamlee81 closed this issue 5 years ago

chowkamlee81 commented 5 years ago

When executing

python ./pytorch/train.py train --config_path=/home/ubuntu/LIDAR/Traveller59/second/configs/pointpillars/car/xyres_16.config --model_dir=../model_pytorch

I am getting the NaN values below. Kindly help.

runtime.step=1900, runtime.steptime=0.1566, loss.cls_loss=nan, loss.cls_loss_rt=nan, loss.loc_loss=nan, loss.loc_loss_rt=nan, loss.loc_elem=[nan, nan, nan, nan, nan, nan, nan], loss.cls_pos_rt=nan, loss.cls_neg_rt=nan, loss.dir_rt=nan, rpn_acc=0.9963, pr.prec@10=0.0, pr.rec@10=0.0, pr.prec@30=0.0, pr.rec@30=0.0, pr.prec@50=0.0, pr.rec@50=0.0, pr.prec@70=0.0, pr.rec@70=0.0, pr.prec@80=0.0, pr.rec@80=0.0, pr.prec@90=0.0, pr.rec@90=0.0, pr.prec@95=0.0, pr.rec@95=0.0, misc.num_vox=10896, misc.num_pos=92, misc.num_neg=23658, misc.num_anchors=23883, misc.lr=0.0003174
runtime.step=1950, runtime.steptime=0.1747, loss.cls_loss=nan, loss.cls_loss_rt=nan, loss.loc_loss=nan, loss.loc_loss_rt=nan, loss.loc_elem=[nan, nan, nan, nan, nan, nan, nan], loss.cls_pos_rt=nan, loss.cls_neg_rt=nan, loss.dir_rt=nan, rpn_acc=0.9962, pr.prec@10=0.0, pr.rec@10=0.0, pr.prec@30=0.0, pr.rec@30=0.0, pr.prec@50=0.0, pr.rec@50=0.0, pr.prec@70=0.0, pr.rec@70=0.0, pr.prec@80=0.0, pr.rec@80=0.0, pr.prec@90=0.0, pr.rec@90=0.0, pr.prec@95=0.0, pr.rec@95=0.0, misc.num_vox=7509, misc.num_pos=96, misc.num_neg=11006, misc.num_anchors=11237, misc.lr=0.0003183
runtime.step=2000, runtime.steptime=0.1833, loss.cls_loss=nan, loss.cls_loss_rt=nan, loss.loc_loss=nan, loss.loc_loss_rt=nan, loss.loc_elem=[nan, nan, nan, nan, nan, nan, nan], loss.cls_pos_rt=nan, loss.cls_neg_rt=nan, loss.dir_rt=nan, rpn_acc=0.9962, pr.prec@10=0.0, pr.rec@10=0.0, pr.prec@30=0.0, pr.rec@30=0.0, pr.prec@50=0.0, pr.rec@50=0.0, pr.prec@70=0.0, pr.rec@70=0.0, pr.prec@80=0.0, pr.rec@80=0.0, pr.prec@90=0.0, pr.rec@90=0.0, pr.prec@95=0.0, pr.rec@95=0.0, misc.num_vox=13329, misc.num_pos=114, misc.num_neg=21720, misc.num_anchors=22004, misc.lr=0.0003193
runtime.step=2050, runtime.steptime=0.1915, loss.cls_loss=nan, loss.cls_loss_rt=nan, loss.loc_loss=nan, loss.loc_loss_rt=nan, loss.loc_elem=[nan, nan, nan, nan, nan, nan, nan], loss.cls_pos_rt=nan, loss.cls_neg_rt=nan, loss.dir_rt=nan, rpn_acc=0.9963, pr.prec@10=0.0, pr.rec@10=0.0, pr.prec@30=0.0, pr.rec@30=0.0, pr.prec@50=0.0, pr.rec@50=0.0, pr.prec@70=0.0, pr.rec@70=0.0, pr.prec@80=0.0, pr.rec@80=0.0, pr.prec@90=0.0, pr.rec@90=0.0, pr.prec@95=0.0, pr.rec@95=0.0, misc.num_vox=12526, misc.num_pos=106, misc.num_neg=19512, misc.num_anchors=19785, misc.lr=0.0003202

traveller59 commented 5 years ago

Could you please change your title? It looks terrible.

Does the 'nan' appear in the first log, or only after some steps?

Please don't use a relative model dir path. I will add code to check this in the next update.
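Such a check could look like the following minimal sketch. It assumes a hypothetical helper named resolve_model_dir, which is not part of second.pytorch. The helper either rejects a relative --model_dir or resolves it to an absolute path, so the run no longer depends on the directory train.py is launched from.

```python
from pathlib import Path


def resolve_model_dir(model_dir: str, strict: bool = False) -> Path:
    """Hypothetical helper: turn a --model_dir argument into an absolute path.

    With strict=True a relative path is rejected outright; otherwise it is
    resolved against the current working directory.
    """
    path = Path(model_dir).expanduser()  # expand a leading "~"
    if not path.is_absolute():
        if strict:
            raise ValueError(f"model_dir must be an absolute path, got {model_dir!r}")
        path = path.resolve()  # anchor the relative path at the current cwd
    return path
```

Resolving once at startup means every later checkpoint and log write goes to the same place, even if some code path changes the working directory.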

chowkamlee81 commented 5 years ago

The model checkpoint voxelnet-7750.tckpt has already been saved, but everything in my TensorBoard results looks nil, and the loss shows no exponential decay.

I'm using 2 RTX 2080 Ti GPUs. Kindly help.

chowkamlee81 commented 5 years ago

@traveller59

chowkamlee81 commented 5 years ago

Kindly help

traveller59 commented 5 years ago

Could you provide the log.txt in the model dir? Consider using simple-inference.ipynb to check the result. If it is wrong, you can debug it and find the problematic module.
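You can scan log.txt for the first record with a non-finite loss. This shows whether the nan appears in the first log or only after some steps. This minimal sketch assumes that log.txt records use the same key=value format as the log pasted in this issue, and that each record starts with runtime.step=. The function name first_nonfinite_step is hypothetical.

```python
import math
import re

# Assumption: each record looks like "runtime.step=1900, runtime.steptime=..., loss.cls_loss=nan, ..."
RECORD_RE = re.compile(r"runtime\.step=(\d+)(.*?)(?=runtime\.step=|\Z)", re.S)
LOSS_KEYS = ("loss.cls_loss", "loss.loc_loss")


def first_nonfinite_step(log_text: str):
    """Return the first logged step whose cls or loc loss is nan/inf, else None."""
    for match in RECORD_RE.finditer(log_text):
        step, body = int(match.group(1)), match.group(2)
        for key in LOSS_KEYS:
            # "=" right after the key, so loss.cls_loss_rt is not matched
            value = re.search(re.escape(key) + r"=([^,\s]+)", body)
            if value and not math.isfinite(float(value.group(1))):
                return step
    return None
```

For example, first_nonfinite_step(open("/abs/path/to/model_dir/log.txt").read()) returns a step number, or None if every logged cls_loss and loc_loss is finite. The path is a placeholder.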

chowkamlee81 commented 5 years ago

log.txt

chowkamlee81 commented 5 years ago

I tried my trained model with simple-inference.ipynb for validation. It seems I'm not getting any car bounding boxes with car.lite.config. Kindly suggest how to proceed.

traveller59 commented 5 years ago

You need to use my pretrained model with simple-inference to debug... You need to add print calls to the forward method of VoxelNet to find the problem when using simple-inference.

chowkamlee81 commented 5 years ago

I am able to detect car bounding boxes with your pretrained model for all 3 types, but I'm not getting any results with my own model. I followed the same procedure as listed in the docs. Kindly help.

traveller59 commented 5 years ago

Can you train on the KITTI dataset correctly? If you are using custom data, did you use the web visualization tool to check the bounding boxes?

chowkamlee81 commented 5 years ago

The problem was related to the GPU, and I reformatted the system. It is working now. Thanks for your input.

dingfuzhou commented 5 years ago

@chowkamlee81 Hi, I have faced the same problem. Could you please give me some details about how this problem happened and how you solved it? Thanks very much in advance! Best,

eddyhkchiu commented 5 years ago

I also got NaN issues when training with all.pp.lowa.config. Here are my fixes and workarounds:

  1. second/pytorch/core/box_torch_ops.py: in all torch.sqrt() and torch.log() calls, add a small number eps to the input, for example:

         eps = 1e-8
         diagonal = torch.sqrt(la**2 + wa**2 + eps)

     The gradients of these functions can be infinite when the input is 0 (for sqrt, d/dx sqrt(x) = 1 / (2 sqrt(x))). Adding a small eps avoids this issue.

  2. second/pytorch/models/pointpillars.py, class PillarFeatureNet, forward() function: the following original code may divide by 0 if a pillar contains no points:

         points_mean = features[:, :, :3].sum(dim=1, keepdim=True) / num_voxels.type_as(features).view(-1, 1, 1)

     I use the following workaround:

         num_voxels_set_0_to_1 = num_voxels.clone()
         num_voxels_set_0_to_1[num_voxels_set_0_to_1 == 0] = 1
         points_mean = features[:, :, :3].sum(dim=1, keepdim=True) / num_voxels_set_0_to_1.type_as(features).view(-1, 1, 1)

  3. second/pytorch/core/losses.py, _softmax_cross_entropy_with_logits() function: I subtract the max values before feeding the logits into nn.CrossEntropyLoss(), as follows:

         logits_max, _ = torch.max(logits, 1, keepdim=True)
         logits = logits - logits_max

     Without this subtraction, CrossEntropyLoss may produce Inf or NaN values if the original logits are large.

  4. Avoid using batch_size 1 with BatchNorm1d (https://discuss.pytorch.org/t/nan-when-i-use-batch-normalization-batchnorm1d/322). I set the evaluation batch_size to 2 in all.pp.lowa.config. In the train() function of second/pytorch/train.py, I skip any training iteration whose batch_size is 1.
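Workarounds 1 to 3 can be checked in isolation on toy tensors. The function names here are hypothetical and the shapes are simplified (logits as (N, num_classes)), so this is not a drop-in patch for box_torch_ops.py, pointpillars.py or losses.py. For workaround 2 it uses num_voxels.clamp(min=1), which is equivalent to the clone-and-assign version for non-negative counts.

```python
import torch
import torch.nn.functional as F


# Workaround 1: d/dx sqrt(x) = 1 / (2 * sqrt(x)) is infinite at x == 0, so a
# small eps inside the sqrt keeps the backward pass finite.
def safe_diagonal(la: torch.Tensor, wa: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    return torch.sqrt(la ** 2 + wa ** 2 + eps)


# Workaround 2: per-pillar mean of the xyz coordinates that never divides by 0.
# features: (num_pillars, max_points, C); num_voxels: (num_pillars,) point counts.
def pillar_points_mean(features: torch.Tensor, num_voxels: torch.Tensor) -> torch.Tensor:
    counts = num_voxels.clamp(min=1).type_as(features).view(-1, 1, 1)
    return features[:, :, :3].sum(dim=1, keepdim=True) / counts


# Workaround 3: shift each row of logits by its max before cross-entropy.
# logits: (N, num_classes); target: (N,) class indices.
def shifted_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    logits_max, _ = torch.max(logits, dim=1, keepdim=True)
    return F.cross_entropy(logits - logits_max, target, reduction="none")
```

Softmax is shift-invariant, so subtracting the row max leaves the loss value mathematically unchanged. PyTorch's log_softmax, which F.cross_entropy uses, generally applies a similar shift internally. How much workaround 3 changes in practice may therefore depend on the PyTorch version and code path.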

After I implemented the above fixes and workarounds, I did not see the NaN issues again. Hope those also help your case.
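Workaround 4 can be sketched as a guard in the training loop. The helper usable_batches and the toy dataset are hypothetical. DataLoader's drop_last=True is shown as an alternative that does not come from the comment. Recent PyTorch versions raise a ValueError for training-mode BatchNorm1d on a (1, C) input instead of silently returning NaN. Whether a size-1 batch actually reaches a BatchNorm1d layer this way depends on the model.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def usable_batches(loader):
    """Yield only batches with more than one sample (workaround 4 as described).

    A batch of size 1 gives training-mode BatchNorm1d a single value per
    channel to normalize, so that iteration is skipped.
    """
    for batch in loader:
        (x,) = batch
        if x.shape[0] == 1:
            continue
        yield batch


# Toy dataset of 5 samples: batch_size=2 leaves a final batch of 1.
dataset = TensorDataset(torch.arange(5, dtype=torch.float32).view(5, 1))

skipped = [b[0].shape[0] for b in usable_batches(DataLoader(dataset, batch_size=2))]
# Alternative (not from the comment): let DataLoader drop the short last batch.
dropped = [b[0].shape[0] for b in DataLoader(dataset, batch_size=2, drop_last=True)]
print(skipped, dropped)  # [2, 2] [2, 2]
```

drop_last only removes a short final batch. It does not help if batch_size itself is set to 1.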

Sreeni1204 commented 4 years ago

Hello

How did you generate those graphs?

Could you please provide me some details or hints on generating them?

Hetali-Vekariya commented 2 years ago

Hello @Sreeni1204,

You have to install TensorBoard and TensorFlow:

    pip install tensorboard
    pip install tensorflow

Then run the following command and open http://localhost:6006/ in a web browser to see the graphs:

    python -m tensorboard.main --logdir=path/to/saved/trained/modeldir