Henry1iu / TNT-Trajectory-Prediction

A Pytorch Implementation of TNT: Target-driveN Trajectory Prediction

Missing lambda parameters #22

Closed: wangqf1997 closed this issue 2 years ago

wangqf1997 commented 2 years ago

Running the code fails with the following error; the parameter lambda1 is missing:

    (TNT) acl@acl-MS-7C98:~/TNT$ python train_tnt.py --data_root dataset/interm_data --output_dir run/tnt/ --aux_loss --batch_size 64 --with_cuda --lr 0.0010 --warmup_epoch 30 --lr_update_freq 10 --lr_decay_rate 0.1
    Processing...
    Loading Raw Data...: 100%|████████████████████████| 3561/3561 [00:02<00:00, 1212.07it/s]

    [Argoverse]: The maximum of valid length is 266.
    [Argoverse]: The maximum of no. of candidates is 3357.
    Transforming the data to GraphData...: 100%|███████| 3561/3561 [00:33<00:00, 107.78it/s]
    Done!
    Processing...
    Loading Raw Data...: 100%|████████████████████████| 1197/1197 [00:00<00:00, 1249.45it/s]

    [Argoverse]: The maximum of valid length is 272.
    [Argoverse]: The maximum of no. of candidates is 1356.
    Transforming the data to GraphData...: 100%|███████| 1197/1197 [00:09<00:00, 122.35it/s]
    Done!
    Traceback (most recent call last):
      File "train_tnt.py", line 129, in <module>
        train(args.local_rank, args)
      File "train_tnt.py", line 38, in train
        trainer = TNTTrainer(
      File "/home/acl/TNT/core/trainer/tnt_trainer.py", line 104, in __init__
        self.model.lambda1, self.model.lambda2, self.model.lambda3,
      File "/home/acl/anaconda3/envs/TNT/lib/python3.8/site-packages/torch/nn/modules/module.py", line 947, in __getattr__
        raise AttributeError("'{}' object has no attribute '{}'".format(
    AttributeError: 'TNT' object has no attribute 'lambda1'

Looking at the TNT model, I found that this parameter (lambda1) is mentioned only in the docstring. How should I fix this?

    class TNT(nn.Module):
        def __init__(self,
                     in_channels=8,
                     horizon=30,
                     num_subgraph_layers=3,
                     num_global_graph_layer=1,
                     subgraph_width=64,
                     global_graph_width=64,
                     with_aux=False,
                     aux_width=64,
                     target_pred_hid=64,
                     m=50,
                     motion_esti_hid=64,
                     score_sel_hid=64,
                     temperature=0.01,
                     k=6,
                     device=torch.device("cpu")
                     ):
            """
            TNT algorithm for trajectory prediction
            :param in_channels: int, the number of channels of the input node features
            :param horizon: int, the prediction horizon (prediction length)
            :param num_subgraph_layers: int, the number of subgraph layers
            :param num_global_graph_layer: the number of global interaction layers
            :param subgraph_width: int, the channels of the extracted subgraph features
            :param global_graph_width: int, the channels of the extracted global graph features
            :param with_aux: bool, with aux loss or not
            :param aux_width: int, the hidden dimension of aux recovery mlp
            :param n: int, the number of sampled target candidates
            :param target_pred_hid: int, the hidden dimension of target prediction
            :param m: int, the number of selected candidates
            :param motion_esti_hid: int, the hidden dimension of motion estimation
            :param score_sel_hid: int, the hidden dimension of score module
            :param temperature: float, the temperature when computing the score
            :param k: int, the number of final output trajectories
            :param lambda1: float, the weight of candidate prediction loss
            :param lambda2: float, the weight of motion estimation loss
            :param lambda3: float, the weight of trajectory scoring loss
            :param device: the device for computation
            :param multi_gpu: the multi gpu setting
            """
            super(TNT, self).__init__()
            self.horizon = horizon
            self.m = m
            self.k = k

            self.with_aux = with_aux

            self.device = device

            # feature extraction backbone
            self.backbone = VectorNetBackbone(
                in_channels=in_channels,
                num_subgraph_layres=num_subgraph_layers,
                subgraph_width=subgraph_width,
                num_global_graph_layer=num_global_graph_layer,
                global_graph_width=global_graph_width,
                with_aux=with_aux,
                aux_mlp_width=aux_width,
                device=device
            )

            self.target_pred_layer = TargetPred(
                in_channels=global_graph_width,
                hidden_dim=target_pred_hid,
                m=m,
                device=device
            )
            self.motion_estimator = MotionEstimation(
                in_channels=global_graph_width,
                horizon=horizon,
                hidden_dim=motion_esti_hid
            )
            self.traj_score_layer = TrajScoreSelection(
                feat_channels=global_graph_width,
                horizon=horizon,
                hidden_dim=score_sel_hid,
                temper=temperature,
                device=self.device
            )
            self._init_weight()
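The docstring describes lambda1, lambda2 and lambda3 as the weights of the candidate prediction, motion estimation and trajectory scoring losses. A minimal sketch of how such weights could be combined into one training loss, assuming a plain weighted sum; the helper name `combine_tnt_losses` and the example numbers are hypothetical and not taken from the repository:

```python
def combine_tnt_losses(target_loss: float, motion_loss: float, score_loss: float,
                       lambda1: float, lambda2: float, lambda3: float) -> float:
    """Hypothetical helper: weighted sum of the three TNT loss terms.

    Assumption: the total loss is a plain weighted sum, with
    lambda1 weighting the candidate (target) prediction loss,
    lambda2 the motion estimation loss and
    lambda3 the trajectory scoring loss, as the docstring describes.
    """
    return lambda1 * target_loss + lambda2 * motion_loss + lambda3 * score_loss


if __name__ == "__main__":
    # Illustrative numbers only, not results from the repository.
    total = combine_tnt_losses(2.0, 1.5, 0.5, lambda1=0.5, lambda2=1.0, lambda3=0.25)
    print(total)  # → 2.625
```

If any of the three weights is never set on the model, code that reads them to build this sum fails before training starts, which matches where the traceback stops.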
Henry1iu commented 2 years ago

Hi,

I recently restructured the trainer and accidentally dropped the assignment of the lambda values. It has now been added back.

Best, Jianbang
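The actual fix is not shown in this thread. As a hedged sketch of the general pattern, assuming the trainer keeps reading `model.lambda1`, `model.lambda2` and `model.lambda3` as in the traceback, the model's constructor has to accept the weights and store them as attributes. `TNTSketch`, `BrokenTNTSketch` and `read_loss_weights` are hypothetical names, plain classes stand in for `nn.Module` so the example runs without PyTorch, and the default values are placeholders rather than the repository's settings:

```python
class BrokenTNTSketch:
    """Hypothetical stand-in for the TNT model before the fix.

    The lambdas appear only in the docstring and are never assigned.
    """

    def __init__(self):
        self.k = 6


class TNTSketch:
    """Hypothetical stand-in for the TNT model after the fix."""

    def __init__(self, lambda1: float = 1.0, lambda2: float = 1.0, lambda3: float = 1.0):
        # Placeholder defaults; the repository's real defaults are not shown here.
        # Storing the weights as attributes lets the trainer read them later.
        self.k = 6
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.lambda3 = lambda3


def read_loss_weights(model):
    # Mirrors the attribute access in the traceback (tnt_trainer.py, line 104).
    return model.lambda1, model.lambda2, model.lambda3


if __name__ == "__main__":
    print(read_loss_weights(TNTSketch(lambda1=0.1, lambda2=1.0, lambda3=0.1)))  # → (0.1, 1.0, 0.1)
    try:
        read_loss_weights(BrokenTNTSketch())
    except AttributeError as err:
        print(err)  # → 'BrokenTNTSketch' object has no attribute 'lambda1'
```

With a real `nn.Module`, the same assignments would go in `__init__` after the `super().__init__()` call, alongside the other attributes such as `self.k`.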