Thanks for your great work!
I am trying to apply your model to continual learning. Without modifying your network or dataloader, I encounter the following error. Since it is raised inside your lib function, and several forward passes have already completed successfully, I find it very hard to debug... Could you please help me out?
epoch:0,loss:mse: 20.561 (18.647) kld: 2.017 (2.326) sample: 20.089 (10.814) total_loss: 42.667 (31.787)
epoch:0,loss:mse: 20.943 (37.872) kld: 2.045 (2.312) sample: 20.115 (24.269) total_loss: 43.103 (64.452)
epoch:0,loss:mse: 21.206 (0.451) kld: 2.075 (2.000) sample: 20.142 (0.421) total_loss: 43.423 (2.872)
epoch:0,loss:mse: 20.992 (4.669) kld: 2.075 (2.000) sample: 19.957 (4.490) total_loss: 43.024 (11.158)
epoch:0,loss:mse: 21.082 (12.970) kld: 2.136 (6.585) sample: 20.000 (12.838) total_loss: 43.218 (32.393)
epoch:0,loss:mse: 21.042 (16.777) kld: 2.331 (3.709) sample: 19.912 (16.551) total_loss: 43.286 (37.037)
Traceback (most recent call last):
  File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./main.py", line 189, in <module>
    main()
  File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./main.py", line 162, in main
    task_iter(task, num_devices, pop, generation_id, loop_id, exp_config)
  File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/mu2net_traj/main.py", line 83, in task_iter
    train_loop(paths, ds_train, ds_validation,
  File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./mytrain.py", line 203, in train_loop
    model_data = path.model()
  File "/mnt/petrelfs/tangxiaqiang/miniconda3/./lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer.py", line 596, in forward
    self.inference(sample_num=self.loss_cfg['sample']['k'])
  File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer.py", line 607, in inference
    self.future_decoder(self.data, mode=mode, sample_num=sample_num, autoregress=True, need_weights=need_weights)
  File "/mnt/petrelfs/tangxiaqiang/miniconda3/./lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer.py", line 425, in forward
    self.decode_traj_ar(data, mode, context, pre_motion, pre_vel, pre_motion_scene_norm, z, sample_num, need_weights=need_weights)
  File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer.py", line 341, in decode_traj_ar
    tf_out, attn_weights = self.tf_decoder(tf_in_pos, context, memory_mask=mem_mask, tgt_mask=tgt_mask, num_agent=data['agent_num'], need_weights=need_weights)
  File "/mnt/petrelfs/tangxiaqiang/miniconda3/./lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer_lib.py", line 746, in forward
    output, self_attn_weights[i], cross_attn_weights[i] = mod(output, memory, tgt_mask=tgt_mask,
  File "/mnt/petrelfs/tangxiaqiang/miniconda3/./lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer_lib.py", line 644, in forward
    tgt2, self_attn_weights = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
  File "/mnt/petrelfs/tangxiaqiang/miniconda3/./lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer_lib.py", line 506, in forward
    return agent_aware_attention(
  File "/mnt/petrelfs/tangxiaqiang/code/trajnet/dev/./AgentFormer/model/agentformer_lib.py", line 177, in agent_aware_attention
    raise NotImplementedError
NotImplementedError