joslefaure / HIT

Official Implementation of our WACV2023 paper: “Holistic Interaction Transformer Network for Action Detection”
https://arxiv.org/abs/2210.12686

TypeError: conv3d() received an invalid combination of arguments error during training and inference #44

Closed: Chris009527 closed this issue 6 months ago

Chris009527 commented 6 months ago

Hello author, I get the same error both when I run inference with your pretrained model and when I run the training code. The following is the output from the training code:

    2024-03-13 14:00:53,800 hit.trainer INFO: Start training
    Traceback (most recent call last):
      File "train_net.py", line 255, in <module>
        main()
      File "train_net.py", line 245, in main
        args.no_head)
      File "train_net.py", line 100, in train
        mem_active,
      File "/home/hscc/Desktop/chris/HIT/hit/engine/trainer.py", line 59, in do_train
        loss_dict, weight_dict, metric_dict, pooled_feature = model(slow_video, fast_video, boxes, objects, keypoints, mem_extras)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hscc/Desktop/chris/HIT/hit/modeling/detector/action_detector.py", line 20, in forward
        result, detector_losses, loss_weight, detector_metrics = self.roi_heads(slow_features, fast_features, boxes, objects, keypoints, extras, part_forward)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hscc/Desktop/chris/HIT/hit/modeling/roi_heads/roi_heads_3d.py", line 12, in forward
        result, loss_action, loss_weight, accuracy_action = self.action(slow_features, fast_features, boxes, objects, keypoints, extras, part_forward)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hscc/Desktop/chris/HIT/hit/modeling/roi_heads/action_head/action_head.py", line 44, in forward
        x, x_pooled, x_objects, x_keypoints, x_pose = self.feature_extractor(slow_features, fast_features, proposals, objects, keypoints, extras, part_forward)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hscc/Desktop/chris/HIT/hit/modeling/roi_heads/action_head/roi_action_feature_extractor.py", line 145, in forward
        ia_feature, res_person, res_object, res_keypoint = self.hit_structure(person_pooled, proposals, object_pooled, objects, hands_pooled, keypoints, memory_person, None, None, phase="rgb")
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hscc/Desktop/chris/HIT/hit/modeling/roi_heads/action_head/hit_structure.py", line 213, in forward
        mem_feature, phase)
      File "/home/hscc/Desktop/chris/HIT/hit/modeling/roi_heads/action_head/hit_structure.py", line 230, in _reduce_dim
        query = self.person_dim_reduce(person)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 590, in forward
        return self._conv_forward(input, self.weight, self.bias)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 586, in _conv_forward
        input, weight, bias, self.stride, self.padding, self.dilation, self.groups
    TypeError: conv3d() received an invalid combination of arguments - got (Tensor, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:

And the following is the output from the inference code:

    2024-03-13 14:04:58,243 hit.inference INFO: Start evaluation on jhmdb_val dataset(9121 videos).
    2024-03-13 14:04:58,249 hit.inference INFO: Stage 1: extracting clip features.
    100%|██████████| 1141/1141 [10:33<00:00, 1.80it/s]
    2024-03-13 14:15:31,295 hit.inference INFO: Stage 1 time: 0:10:33.046037 (0.06940533238902268 s / video per device, on 1 devices)
    2024-03-13 14:15:31,296 hit.inference INFO: Stage 2: predicting with extracted feature.
    0%| | 0/1141 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "test_net.py", line 95, in <module>
        main()
      File "test_net.py", line 89, in main
        output_folder=output_folder,
      File "/home/hscc/Desktop/chris/HIT/hit/engine/inference.py", line 182, in inference
        predictions = compute_on_dataset(model, data_loader, device, logger, mem_active)
      File "/home/hscc/Desktop/chris/HIT/hit/engine/inference.py", line 139, in compute_on_dataset
        results_dict = compute_on_dataset_2stage(model, data_loader, device, logger)
      File "/home/hscc/Desktop/chris/HIT/hit/engine/inference.py", line 116, in compute_on_dataset_2stage
        output = model(None, None, None, None, extras=extras, part_forward=1)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hscc/Desktop/chris/HIT/hit/modeling/detector/action_detector.py", line 20, in forward
        result, detector_losses, loss_weight, detector_metrics = self.roi_heads(slow_features, fast_features, boxes, objects, keypoints, extras, part_forward)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hscc/Desktop/chris/HIT/hit/modeling/roi_heads/roi_heads_3d.py", line 12, in forward
        result, loss_action, loss_weight, accuracy_action = self.action(slow_features, fast_features, boxes, objects, keypoints, extras, part_forward)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hscc/Desktop/chris/HIT/hit/modeling/roi_heads/action_head/action_head.py", line 44, in forward
        x, x_pooled, x_objects, x_keypoints, x_pose = self.feature_extractor(slow_features, fast_features, proposals, objects, keypoints, extras, part_forward)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hscc/Desktop/chris/HIT/hit/modeling/roi_heads/action_head/roi_action_feature_extractor.py", line 145, in forward
        ia_feature, res_person, res_object, res_keypoint = self.hit_structure(person_pooled, proposals, object_pooled, objects, hands_pooled, keypoints, memory_person, None, None, phase="rgb")
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hscc/Desktop/chris/HIT/hit/modeling/roi_heads/action_head/hit_structure.py", line 213, in forward
        mem_feature, phase)
      File "/home/hscc/Desktop/chris/HIT/hit/modeling/roi_heads/action_head/hit_structure.py", line 230, in _reduce_dim
        query = self.person_dim_reduce(person)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 590, in forward
        return self._conv_forward(input, self.weight, self.bias)
      File "/home/hscc/anaconda3/envs/hit/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 586, in _conv_forward
        input, weight, bias, self.stride, self.padding, self.dilation, self.groups
    TypeError: conv3d() received an invalid combination of arguments - got (Tensor, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:

I hope you can tell me how to solve this problem. Thanks.
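In the last frame of both tracebacks, Conv3d._conv_forward passes seven values to conv3d: input, weight, bias, stride, padding, dilation, groups. Lining these up with "got (Tensor, Parameter, Parameter, tuple, tuple, tuple, int)" shows that every outer type is plausible, which points to bad contents inside one of the three tuples. Assume that newer PyTorch rejects non-int entries such as bool or float in these tuples, where 1.7.1 tolerated them. This assumption is not confirmed in this thread. Under it, a small check like the sketch below can flag the offending layer. The helper names are hypothetical and are not part of HIT or PyTorch. The sketch uses only the standard library, so it can be tested without a GPU.

```python
from typing import Iterable, List, Tuple

Problem = Tuple[str, int, str]  # (argument name, index in tuple, offending type)


def bad_entries(name: str, values: Iterable[object]) -> List[Problem]:
    """Flag tuple entries that are not plain ints.

    bool is a subclass of int in Python, so it is excluded explicitly.
    """
    problems: List[Problem] = []
    for i, value in enumerate(values):
        if isinstance(value, bool) or not isinstance(value, int):
            problems.append((name, i, type(value).__name__))
    return problems


def check_conv_layer(stride, padding, dilation) -> List[Problem]:
    """Check the three tuple-valued arguments that Conv3d forwards to conv3d."""
    problems: List[Problem] = []
    for name, values in (("stride", stride), ("padding", padding), ("dilation", dilation)):
        # Newer PyTorch also accepts padding="same"/"valid"; skip string modes.
        if isinstance(values, str):
            continue
        problems.extend(bad_entries(name, values))
    return problems


if __name__ == "__main__":
    # Healthy layer: stride 1, padding 0, dilation 1 on all three axes.
    print(check_conv_layer((1, 1, 1), (0, 0, 0), (1, 1, 1)))
    # -> []
    # Hypothetical broken layer whose stride received a boolean flag.
    print(check_conv_layer((True, True, True), (0, 0, 0), (1, 1, 1)))
    # -> [('stride', 0, 'bool'), ('stride', 1, 'bool'), ('stride', 2, 'bool')]
```

With PyTorch installed, you could run the same check over a built model. Iterate model.named_modules() and call check_conv_layer(m.stride, m.padding, m.dilation) for every torch.nn.Conv3d module m. Any non-empty result identifies a layer worth inspecting, such as person_dim_reduce from the traceback.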

joslefaure commented 6 months ago

You are probably using a PyTorch version newer than 1.7.1. If you don't want to downgrade, update all of the Conv3d occurrences to use named (keyword) arguments, for instance nn.Conv3d(in_channels=xxx, out_channels=xxx, ...).
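Named arguments help because of how nn.Conv3d orders its parameters: in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, and so on. Any fourth positional argument therefore binds to stride, never to bias. To apply the advice above, it helps to first list every Conv3d call that passes more than the three required arguments positionally. The sketch below is a hypothetical helper, not part of HIT, that does this with Python's standard ast module. It matches calls by the name Conv3d only and does not resolve import aliases.

```python
import ast
import sys
from pathlib import Path
from typing import List, Tuple

# nn.Conv3d's required parameters are in_channels, out_channels, kernel_size;
# a 4th positional argument binds to stride, a 5th to padding, and so on.
REQUIRED_POSITIONAL = 3


def find_positional_conv3d_calls(source: str) -> List[Tuple[int, int]]:
    """Return (line, positional-arg count) for Conv3d(...) calls that pass
    more than the three required arguments positionally."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        if isinstance(func, ast.Attribute):  # e.g. nn.Conv3d(...)
            name = func.attr
        elif isinstance(func, ast.Name):  # e.g. Conv3d(...)
            name = func.id
        else:
            continue
        if name == "Conv3d" and len(node.args) > REQUIRED_POSITIONAL:
            hits.append((node.lineno, len(node.args)))
    return sorted(hits)


def scan_tree(root: Path) -> None:
    """Print every suspicious Conv3d call in the .py files under root."""
    for path in sorted(root.rglob("*.py")):
        try:
            calls = find_positional_conv3d_calls(path.read_text(encoding="utf-8"))
        except (SyntaxError, ValueError):  # ValueError covers decode errors
            continue  # skip files this interpreter cannot parse
        for lineno, count in calls:
            print(f"{path}:{lineno}: Conv3d with {count} positional arguments")


if __name__ == "__main__" and len(sys.argv) > 1:
    scan_tree(Path(sys.argv[1]))
```

For example, saving this as a hypothetical find_conv3d.py and running python find_conv3d.py hit/ would list candidate lines. Each hit then needs rewriting with keyword arguments according to what the extra argument was meant to be. If a call such as nn.Conv3d(256, 512, 1, True) intended its fourth argument as the bias flag, it would become nn.Conv3d(in_channels=256, out_channels=512, kernel_size=1, bias=True).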