Henry1iu / TNT-Trajectory-Prediction

A PyTorch Implementation of TNT: Target-driveN Trajectory Prediction

Got an error in torch_geometric #46

Open ImCabbage opened 1 year ago

ImCabbage commented 1 year ago

I got an error when trying apex for multi-GPU training. Any idea what the message below means?

Traceback (most recent call last):
  File "train_lotus.py", line 117, in <module>
    train(args.local_rank, args)
  File "train_lotus.py", line 58, in train
    _ = trainer.train(iter_epoch)
  File "/root/cabbage/decisiontraining/core/trainer.py", line 129, in train
    return self.iteration(epoch, self.train_loader)
  File "/root/cabbage/decisiontraining/core/vectornet_trainer.py", line 132, in iteration
    loss = self.compute_loss(data)
  File "/root/cabbage/decisiontraining/core/vectornet_trainer.py", line 261, in compute_loss
    out = self.model(data)
  File "/root/miniconda3/envs/vectornet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/miniconda3/envs/vectornet/lib/python3.8/site-packages/apex/amp/_initialize.py", line 196, in new_fwd
    output = old_fwd(*applier(args, input_caster),
  File "/root/miniconda3/envs/vectornet/lib/python3.8/site-packages/apex/amp/_initialize.py", line 51, in applier
    return type(value)(applier(v, fn) for v in value)
  File "/root/miniconda3/envs/vectornet/lib/python3.8/site-packages/apex/amp/_initialize.py", line 51, in <genexpr>
    return type(value)(applier(v, fn) for v in value)
  File "/root/miniconda3/envs/vectornet/lib/python3.8/site-packages/apex/amp/_initialize.py", line 47, in applier
    return fn(value)
  File "/root/miniconda3/envs/vectornet/lib/python3.8/site-packages/apex/amp/_initialize.py", line 35, in to_type
    return t.to(dtype)
  File "/root/miniconda3/envs/vectornet/lib/python3.8/site-packages/torch_geometric/data/data.py", line 251, in to
    return self.apply(
  File "/root/miniconda3/envs/vectornet/lib/python3.8/site-packages/torch_geometric/data/data.py", line 234, in apply
    store.apply(func, *args)
  File "/root/miniconda3/envs/vectornet/lib/python3.8/site-packages/torch_geometric/data/storage.py", line 163, in apply
    self[key] = recursive_apply(value, func)
  File "/root/miniconda3/envs/vectornet/lib/python3.8/site-packages/torch_geometric/data/storage.py", line 523, in recursive_apply
    return func(data)
  File "/root/miniconda3/envs/vectornet/lib/python3.8/site-packages/torch_geometric/data/data.py", line 252, in <lambda>
    lambda x: x.to(device=device, non_blocking=non_blocking), *args)
TypeError: to() received an invalid combination of arguments - got (non_blocking=bool, device=torch.dtype, ), but expected one of:
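Reading the last frames bottom-up hints at the cause: apex's `to_type` calls `t.to(dtype)` on the batch object itself, the torch_geometric `to` frame forwards its first positional argument to every tensor as `device=`, and so each tensor receives a dtype through `device=`, which is the `device=torch.dtype` combination the `TypeError` rejects. The stdlib-only mock below is a sketch that replays that chain without a GPU. Every name in it (`FakeDtype`, `FakeTensor`, `FakeData`, `DtypeAwareData`, `apex_style_to_type`) is a hypothetical stand-in that only copies the call signatures visible in the traceback; it is not the real torch, apex, or torch_geometric code. `DtypeAwareData` illustrates one possible workaround (a data class whose `to` accepts a dtype and casts only floating-point attributes); whether the same idea carries over safely to a real `torch_geometric.data.Data` subclass is an untested assumption.

```python
# Stdlib-only mock of the call chain in the traceback above.
# All names here are hypothetical stand-ins, not the real torch, apex or
# torch_geometric classes; they only copy the call signatures the traceback shows.


class FakeDtype:
    """Stands in for a torch.dtype such as a half-precision type."""

    def __init__(self, name):
        self.name = name


class FakeTensor:
    """Stands in for a tensor: .to() takes a dtype positionally or a device via device=."""

    def __init__(self, dtype_name, floating=True):
        self.dtype_name = dtype_name
        self.floating = floating

    def to(self, dtype=None, *, device=None, non_blocking=False):
        # Reject a dtype passed through the device= keyword, like the error above.
        if isinstance(device, FakeDtype):
            raise TypeError("to() got device=dtype, which is not a valid device")
        if isinstance(dtype, FakeDtype):
            return FakeTensor(dtype.name, self.floating)
        return self  # device moves are a no-op in this mock


class FakeData:
    """Mimics the `to` frame: the first positional argument is forwarded as device=."""

    def __init__(self, **tensors):
        self.tensors = dict(tensors)

    def to(self, device, *args, non_blocking=False):
        # Extra *args are accepted but ignored in this mock.
        for key, value in list(self.tensors.items()):
            self.tensors[key] = value.to(device=device, non_blocking=non_blocking)
        return self


class DtypeAwareData(FakeData):
    """Hypothetical workaround: cast only floating attributes when handed a dtype."""

    def to(self, device, *args, non_blocking=False):
        if isinstance(device, FakeDtype):  # apex passed a dtype, not a device
            for key, value in list(self.tensors.items()):
                if value.floating:  # leave integer index tensors untouched
                    self.tensors[key] = value.to(device)
            return self
        return super().to(device, *args, non_blocking=non_blocking)


def apex_style_to_type(dtype, t):
    """Mimics the apex frame `return t.to(dtype)` for a non-tensor batch object."""
    return t.to(dtype)


if __name__ == "__main__":
    half = FakeDtype("float16")

    try:
        apex_style_to_type(half, FakeData(x=FakeTensor("float32")))
    except TypeError as exc:
        print("reproduced:", exc)

    fixed = DtypeAwareData(x=FakeTensor("float32"),
                           edge_index=FakeTensor("int64", floating=False))
    apex_style_to_type(half, fixed)
    print({k: v.dtype_name for k, v in fixed.tensors.items()})
```

Run as a script, it first prints the reproduced `TypeError` for `FakeData`, then `{'x': 'float16', 'edge_index': 'int64'}` for `DtypeAwareData`. Skipping non-floating attributes is a choice of this sketch, on the assumption that integer index tensors should keep their dtype.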