MCZhi / DIPP

[TNNLS] Differentiable Integrated Prediction and Planning Framework for Urban Autonomous Driving
https://mczhi.github.io/DIPP/

RuntimeError #29

Closed: git0929 closed this issue 4 months ago

git0929 commented 6 months ago

Hello, thank you very much for your great work. I ran into the error below while running training and I don't know what is causing it. Could you please take a look? I would be very grateful.

Traceback (most recent call last):
  File "train.py", line 254, in <module>
    model_training()
  File "train.py", line 211, in model_training
    val_loss, val_metrics = valid_epoch(valid_loader, predictor, planner, args.use_planning)
  File "train.py", line 110, in valid_epoch
    plans, predictions, scores, cost_function_weights = predictor(ego, neighbors, map_lanes, map_crosswalks)
  File "/home/hou/anaconda3/envs/DIPP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hou/DIPP/model/predictor.py", line 236, in forward
    agent_agent = self.agent_agent(actors, actor_mask)
  File "/home/hou/anaconda3/envs/DIPP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hou/DIPP/model/predictor.py", line 106, in forward
    output = self.interaction_net(inputs, src_key_padding_mask=mask)
  File "/home/hou/anaconda3/envs/DIPP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hou/anaconda3/envs/DIPP/lib/python3.8/site-packages/torch/nn/modules/transformer.py", line 238, in forward
    output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
  File "/home/hou/anaconda3/envs/DIPP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/hou/anaconda3/envs/DIPP/lib/python3.8/site-packages/torch/nn/modules/transformer.py", line 437, in forward
    return torch._transformer_encoder_layer_fwd(
RuntimeError: Mask shape should match input shape; transformer_mask is not supported in the fallback case.

MCZhi commented 6 months ago

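Hi @git0929, thank you for your interest in our work. If you are using a newer version of PyTorch, please pass the argument enable_nested_tensor=False when the encoder is constructed in model/predictor.py, i.e. change self.interaction_net = nn.TransformerEncoder(encoder_layer, num_layers=2) to self.interaction_net = nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False).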
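For illustration, here is a minimal sketch of that one-line change in context. It is not DIPP's actual code: the layer settings (d_model=256, nhead=8, batch_first=True) and the helper name build_interaction_net are hypothetical, so check model/predictor.py for the real values. Because older PyTorch releases do not accept enable_nested_tensor and would raise a TypeError on the unknown keyword, the sketch only passes it when the installed nn.TransformerEncoder constructor has that parameter. It then runs a padded batch in eval mode with gradients disabled, which is assumed to mirror the validation loop where the error appeared; it does not by itself prove the original error is gone on every PyTorch version.

```python
import inspect

import torch
import torch.nn as nn


def build_interaction_net(d_model: int = 256, nhead: int = 8,
                          num_layers: int = 2) -> nn.TransformerEncoder:
    # Hypothetical layer settings; check model/predictor.py for the real ones.
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=d_model, nhead=nhead, dim_feedforward=4 * d_model,
        activation="relu", batch_first=True)

    # Older PyTorch releases have no enable_nested_tensor parameter and would
    # raise a TypeError, so only pass it when the constructor accepts it.
    extra = {}
    params = inspect.signature(nn.TransformerEncoder.__init__).parameters
    if "enable_nested_tensor" in params:
        extra["enable_nested_tensor"] = False
    return nn.TransformerEncoder(encoder_layer, num_layers=num_layers, **extra)


# Usage: a padded batch run in eval mode without gradients (assumed to match
# how the validation loop calls the model).
net = build_interaction_net().eval()
x = torch.randn(4, 11, 256)                  # (batch, agents, features)
mask = torch.zeros(4, 11, dtype=torch.bool)
mask[:, 6:] = True                           # True marks padded agents
with torch.no_grad():
    out = net(x, src_key_padding_mask=mask)
print(out.shape)                             # torch.Size([4, 11, 256])
```

Probing the signature keeps one code path working across PyTorch versions; if you only ever run a recent release, passing enable_nested_tensor=False directly, as suggested above, is simpler.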