Closed chowkamlee81 closed 5 years ago
Could you please change your title? It looks terrible.
Does the 'nan' appear in the first log entry, or only after some steps?
Please don't use a relative model dir path. I will add code to check this in the next update.
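Until such a check exists, one way to avoid a relative model dir is to resolve it to an absolute path before passing it to the training script. The helper below is only a sketch; the name resolve_model_dir is hypothetical and is not part of the repository.

```python
from pathlib import Path

def resolve_model_dir(model_dir):
    """Return an absolute version of model_dir (hypothetical helper).

    expanduser() handles a leading '~'; resolve() makes the path absolute
    relative to the current working directory and collapses '..' parts.
    """
    return Path(model_dir).expanduser().resolve()

# Example: turn the relative path into one that no longer depends on
# the directory the script is launched from.
print(resolve_model_dir("../model_pytorch"))
```

The printed path can then be used as the value of --model_dir.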
The model voxelnet-7750.tckpt has already been saved, but my TensorBoard results all look empty and the loss shows no exponential decay.
I'm using two RTX 2080 Ti GPUs. Kindly help.
Could you provide the log.txt from the model dir? Consider using simple-inference.ipynb to check the result. If it is wrong, you can debug and find the problem module.
I tried my trained model with simple-inference.ipynb for validation. It seems I'm not getting any bounding boxes for car.lite.config. Kindly suggest how to proceed.
You need to use my pretrained model with simple-inference to debug... You need to add print calls to the forward method of VoxelNet to find the problem when using simple-inference.
I am able to detect car bounding boxes with your trained model for all 3 types, but I'm not getting any results with my own model. Kindly help. I followed the same procedure as listed in the doc.
Can you train on the KITTI dataset correctly? If you are using custom data, did you use the web visualization tool to check the bounding boxes?
The problem was with the GPU, and I reformatted the system. It is working now. Thanks for your input.
@chowkamlee81 Hi, I have faced the same problem as you. Could you please give me some details about how this problem happens and how you solved it? Thanks very much in advance!
I also ran into NaN issues when training with all.pp.lowa.config. Here are my fixes and workarounds:
second/pytorch/core/box_torch_ops.py
In all torch.sqrt() and torch.log() calls, add a small number eps to the input, for example:
eps = 1e-8
diagonal = torch.sqrt(la**2 + wa**2 + eps)
The gradient of these functions can be infinite when the input is 0. Adding a small number eps avoids this issue.
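To illustrate how an infinite gradient turns into NaN, here is a minimal sketch in plain Python that applies the chain rule by hand, the way a reverse-mode autodiff engine would. The function name diagonal_grad_la is hypothetical, and the explicit math.inf branch stands in for the IEEE division 1/0 that tensor code would perform silently.

```python
import math

def diagonal_grad_la(la, wa, eps=0.0):
    """Gradient of sqrt(la**2 + wa**2 + eps) with respect to la,
    computed as d(sqrt u)/du * du/d(la), like a backward pass."""
    u = la ** 2 + wa ** 2 + eps
    # d(sqrt u)/du = 1 / (2 * sqrt(u)); floating-point tensors give inf at u == 0.
    dsqrt_du = math.inf if u == 0.0 else 0.5 / math.sqrt(u)
    du_dla = 2.0 * la
    return dsqrt_du * du_dla  # inf * 0.0 evaluates to nan

print(diagonal_grad_la(0.0, 0.0))            # nan
print(diagonal_grad_la(0.0, 0.0, eps=1e-8))  # 0.0
```

With eps added, the inner value never reaches 0, so the outer derivative stays finite and the product is well defined.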
second/pytorch/models/pointpillars.py, class PillarFeatureNet, forward() function:
The following original code may divide by 0 if a pillar does not contain any points:
points_mean = features[:, :, :3].sum(dim=1, keepdim=True) / num_voxels.type_as(features).view(-1, 1, 1)
I use the following workaround:
num_voxels_set_0_to_1 = num_voxels.clone()
num_voxels_set_0_to_1[num_voxels_set_0_to_1==0] = 1
points_mean = features[:, :, :3].sum(dim=1, keepdim=True) / num_voxels_set_0_to_1.type_as(features).view(-1, 1, 1)
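The same idea can be shown without tensors. The sketch below uses plain Python lists and a hypothetical function name, pillar_means; it replaces a point count of 0 with 1 before dividing, so an empty pillar gets a mean of zeros instead of a division by zero.

```python
def pillar_means(coord_sums, num_points):
    """Per-pillar mean of (x, y, z), given summed coordinates and point counts.

    coord_sums: list of [sum_x, sum_y, sum_z], one entry per pillar
    num_points: list of point counts, one entry per pillar (may contain 0)
    """
    means = []
    for sums, n in zip(coord_sums, num_points):
        safe_n = n if n > 0 else 1  # same effect as setting zero counts to 1
        means.append([s / safe_n for s in sums])
    return means

# Second pillar is empty: its sums are all zero and its count is 0.
print(pillar_means([[2.0, 4.0, 6.0], [0.0, 0.0, 0.0]], [2, 0]))
# [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]
```

In PyTorch, num_voxels.clamp(min=1) should give the same counts as the clone-and-assign workaround in a single call, though I have not tested that variant in this codebase.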
second/pytorch/core/losses.py: _softmax_cross_entropy_with_logits() function:
I subtract max values before feeding logits into nn.CrossEntropyLoss(), as follows:
logits_max, _ = torch.max(logits, 1, keepdim=True)
logits = logits - logits_max
Without subtracting the max values, CrossEntropyLoss may return Inf or NaN if the original logits are large.
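Why the shift is safe and why it helps can be seen in a hand-written softmax cross-entropy. This is only a sketch in plain Python with hypothetical function names; it does not reproduce what nn.CrossEntropyLoss does internally, which may already apply a similar stabilization depending on the PyTorch version.

```python
import math

def cross_entropy_naive(logits, target):
    """-log(softmax(logits)[target]) computed directly from exp(logit)."""
    exps = [math.exp(z) for z in logits]  # overflows for large logits
    return -math.log(exps[target] / sum(exps))

def cross_entropy_shifted(logits, target):
    """Same loss after subtracting the max logit; softmax is shift-invariant."""
    m = max(logits)
    shifted = [z - m for z in logits]     # largest value is now 0
    log_sum = math.log(sum(math.exp(z) for z in shifted))
    return log_sum - shifted[target]

# Both versions agree on small logits.
print(cross_entropy_naive([1.0, 2.0, 3.0], 2), cross_entropy_shifted([1.0, 2.0, 3.0], 2))

# With a large logit the naive version fails (math.exp raises OverflowError;
# float tensors would produce inf / inf = nan instead), the shifted one does not.
print(cross_entropy_shifted([1000.0, 0.0], 0))  # 0.0
```

Subtracting the same constant from every logit leaves the softmax output unchanged, so the loss value is the same whenever the naive version is computable at all.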
Avoid using batch_size 1 with BatchNorm1d (https://discuss.pytorch.org/t/nan-when-i-use-batch-normalization-batchnorm1d/322). I set the evaluation batch_size to 2 in all.pp.lowa.config. In the train() function of second/pytorch/train.py, I skip any training iteration whose batch_size is 1.
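The skip logic can be sketched as a small generator that wraps the batch iterator. The name usable_batches and the use of len(batch) are assumptions for illustration only; in the real train() loop each batch is a dict of arrays, so the size would have to be read from the first dimension of one of its entries.

```python
def usable_batches(batches, min_batch_size=2):
    """Yield only batches with at least min_batch_size samples (hypothetical helper)."""
    for batch in batches:
        if len(batch) < min_batch_size:
            continue  # skip this iteration, e.g. a trailing batch of size 1
        yield batch

# The single-sample batch in the middle is dropped.
print(list(usable_batches([[1, 2], [3], [4, 5, 6]])))
# [[1, 2], [4, 5, 6]]
```

Skipping drops a small amount of training data per epoch; choosing a batch size that divides the dataset size evenly would avoid the leftover batch altogether.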
After I implemented the above fixes and workarounds, I did not see the NaN issues again. Hope those also help your case.
Hello
How did you generate those graphs?
Could you please provide me some details or hints on generating them?
Hello @Sreeni1204,
You have to install tensorboard and tensorflow: pip install tensorboard and pip install tensorflow.
Then run the command python -m tensorboard.main --logdir=path/to/saved/trained/modeldir
Open http://localhost:6006/ in the web browser and you can see the graphs.
When executing
python ./pytorch/train.py train --config_path=/home/ubuntu/LIDAR/Traveller59/second/configs/pointpillars/car/xyres_16.config --model_dir=../model_pytorch
I am getting the NaN values below. Kindly help.
runtime.step=1900, runtime.steptime=0.1566, loss.cls_loss=nan, loss.cls_loss_rt=nan, loss.loc_loss=nan, loss.loc_loss_rt=nan, loss.loc_elem=[nan, nan, nan, nan, nan, nan, nan], loss.cls_pos_rt=nan, loss.cls_neg_rt=nan, loss.dir_rt=nan, rpn_acc=0.9963, pr.prec@10=0.0, pr.rec@10=0.0, pr.prec@30=0.0, pr.rec@30=0.0, pr.prec@50=0.0, pr.rec@50=0.0, pr.prec@70=0.0, pr.rec@70=0.0, pr.prec@80=0.0, pr.rec@80=0.0, pr.prec@90=0.0, pr.rec@90=0.0, pr.prec@95=0.0, pr.rec@95=0.0, misc.num_vox=10896, misc.num_pos=92, misc.num_neg=23658, misc.num_anchors=23883, misc.lr=0.0003174
runtime.step=1950, runtime.steptime=0.1747, loss.cls_loss=nan, loss.cls_loss_rt=nan, loss.loc_loss=nan, loss.loc_loss_rt=nan, loss.loc_elem=[nan, nan, nan, nan, nan, nan, nan], loss.cls_pos_rt=nan, loss.cls_neg_rt=nan, loss.dir_rt=nan, rpn_acc=0.9962, pr.prec@10=0.0, pr.rec@10=0.0, pr.prec@30=0.0, pr.rec@30=0.0, pr.prec@50=0.0, pr.rec@50=0.0, pr.prec@70=0.0, pr.rec@70=0.0, pr.prec@80=0.0, pr.rec@80=0.0, pr.prec@90=0.0, pr.rec@90=0.0, pr.prec@95=0.0, pr.rec@95=0.0, misc.num_vox=7509, misc.num_pos=96, misc.num_neg=11006, misc.num_anchors=11237, misc.lr=0.0003183
runtime.step=2000, runtime.steptime=0.1833, loss.cls_loss=nan, loss.cls_loss_rt=nan, loss.loc_loss=nan, loss.loc_loss_rt=nan, loss.loc_elem=[nan, nan, nan, nan, nan, nan, nan], loss.cls_pos_rt=nan, loss.cls_neg_rt=nan, loss.dir_rt=nan, rpn_acc=0.9962, pr.prec@10=0.0, pr.rec@10=0.0, pr.prec@30=0.0, pr.rec@30=0.0, pr.prec@50=0.0, pr.rec@50=0.0, pr.prec@70=0.0, pr.rec@70=0.0, pr.prec@80=0.0, pr.rec@80=0.0, pr.prec@90=0.0, pr.rec@90=0.0, pr.prec@95=0.0, pr.rec@95=0.0, misc.num_vox=13329, misc.num_pos=114, misc.num_neg=21720, misc.num_anchors=22004, misc.lr=0.0003193
runtime.step=2050, runtime.steptime=0.1915, loss.cls_loss=nan, loss.cls_loss_rt=nan, loss.loc_loss=nan, loss.loc_loss_rt=nan, loss.loc_elem=[nan, nan, nan, nan, nan, nan, nan], loss.cls_pos_rt=nan, loss.cls_neg_rt=nan, loss.dir_rt=nan, rpn_acc=0.9963, pr.prec@10=0.0, pr.rec@10=0.0, pr.prec@30=0.0, pr.rec@30=0.0, pr.prec@50=0.0, pr.rec@50=0.0, pr.prec@70=0.0, pr.rec@70=0.0, pr.prec@80=0.0, pr.rec@80=0.0, pr.prec@90=0.0, pr.rec@90=0.0, pr.prec@95=0.0, pr.rec@95=0.0, misc.num_vox=12526, misc.num_pos=106, misc.num_neg=19512, misc.num_anchors=19785, misc.lr=0.0003202
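To find out whether the NaN shows up in the first logged step or only later, the log text can be scanned for the first record whose loss.cls_loss is nan. This is a rough sketch; the function name first_nan_step is hypothetical, and it assumes that log.txt uses the same key=value layout as the console output above.

```python
import re

# Pair each 'runtime.step=N' with the next 'loss.cls_loss=VALUE' that follows it.
# 'loss.cls_loss_rt=' is not matched because '=' must directly follow 'cls_loss'.
RECORD_RE = re.compile(r"runtime\.step=(\d+).*?loss\.cls_loss=([^,\s]+)", re.S)

def first_nan_step(log_text):
    """Return the first logged step whose cls_loss is nan, or None if there is none."""
    for match in RECORD_RE.finditer(log_text):
        step, loss = match.group(1), match.group(2)
        if loss.lower() == "nan":
            return int(step)
    return None
```

For example, calling first_nan_step on the contents of log.txt from the model dir returns the step number to start debugging from, or None if the loss never became NaN.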