joslefaure / HIT

Official Implementation of our WACV2023 paper: “Holistic Interaction Transformer Network for Action Detection”
https://arxiv.org/abs/2210.12686

Error while training the model #7

Closed. piyushandilya closed this issue 1 year ago

piyushandilya commented 1 year ago

I am trying to train the model on JHMDB with the pretrained weights SlowFast-ResNet50-4x16.pth (soft link created) at the path HIT/data/models/pretrained_models/

However, when I run python train_net.py --config-file "config_files/hitnet.yaml", I get a TypeError saying that conv3d() received an invalid combination of arguments:

conv3d() received an invalid combination of arguments - got (Tensor, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (Tensor, Parameter, Parameter, tuple, tuple, tuple, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (Tensor, Parameter, Parameter, tuple, tuple, tuple, int)
  File "/home/stdf/HIT/hit/modeling/roi_heads/action_head/hit_structure.py", line 230, in _reduce_dim
    query = self.person_dim_reduce(person)
  File "/home/stdf/HIT/hit/modeling/roi_heads/action_head/hit_structure.py", line 213, in forward
    mem_feature, phase)
  File "/home/stdf/HIT/hit/modeling/roi_heads/action_head/roi_action_feature_extractor.py", line 145, in forward
    ia_feature, res_person, res_object, res_keypoint = self.hit_structure(person_pooled, proposals, object_pooled, objects, hands_pooled, keypoints, memory_person, None, None, phase="rgb")
  File "/home/stdf/HIT/hit/modeling/roi_heads/action_head/action_head.py", line 44, in forward
    x, x_pooled, x_objects, x_keypoints, x_pose = self.feature_extractor(slow_features, fast_features, proposals, objects, keypoints, extras, part_forward)
  File "/home/stdf/HIT/hit/modeling/roi_heads/roi_heads_3d.py", line 12, in forward
    result, loss_action, loss_weight, accuracy_action = self.action(slow_features, fast_features, boxes, objects, keypoints, extras, part_forward)
  File "/home/stdf/HIT/hit/modeling/detector/action_detector.py", line 20, in forward
    result, detector_losses, loss_weight, detector_metrics = self.roi_heads(slow_features, fast_features, boxes, objects, keypoints, extras, part_forward)
  File "/home/stdf/HIT/hit/engine/trainer.py", line 59, in do_train
    loss_dict, weight_dict, metric_dict, pooled_feature = model(slow_video, fast_video, boxes, objects, keypoints, mem_extras)
  File "/home/stdf/HIT/train_net.py", line 100, in train
    mem_active,
  File "/home/stdf/HIT/train_net.py", line 245, in main
    args.no_head)
  File "/home/stdf/HIT/train_net.py", line 255, in <module>
    main()
TypeError: conv3d() received an invalid combination of arguments - got (Tensor, Parameter, Parameter, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, tuple of ints padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (Tensor, Parameter, Parameter, tuple, tuple, tuple, int)
 * (Tensor input, Tensor weight, Tensor bias, tuple of ints stride, str padding, tuple of ints dilation, int groups)
      didn't match because some of the arguments have invalid types: (Tensor, Parameter, Parameter, tuple, tuple, tuple, int)

According to the traceback, the error is raised by the self.person_dim_reduce(person) call inside _reduce_dim (hit_structure.py, line 230). Am I using the project incorrectly? Can you help me fix this?
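A note for readers who hit the same message: the "got" and "expected" lists look identical at the type level, so the mismatch is hidden inside one of the arguments. One common cause of this pattern, which is not confirmed anywhere in this thread, is a stride, padding, or dilation tuple that contains something other than plain Python ints (for example a float or a NumPy integer). The sketch below is a hypothetical helper, not part of HIT or PyTorch. It only reads the `stride`, `padding`, and `dilation` attributes, so it is demonstrated here on a stand-in object.

```python
from types import SimpleNamespace


def non_int_conv_args(layer):
    """Return {attribute: value} for geometry arguments that are not all plain ints.

    Hypothetical diagnostic helper: `layer` is any object exposing `stride`,
    `padding`, and `dilation`, as torch.nn.Conv3d instances do.
    """
    bad = {}
    for name in ("stride", "padding", "dilation"):
        value = getattr(layer, name, None)
        # Skip missing attributes and string paddings such as "same".
        if value is None or isinstance(value, str):
            continue
        entries = value if isinstance(value, (tuple, list)) else (value,)
        if not all(isinstance(v, int) for v in entries):
            bad[name] = value
    return bad


# Stand-in object with a float in the padding tuple (an assumed failure case).
layer = SimpleNamespace(stride=(1, 1, 1), padding=(0.5, 0, 0), dilation=(1, 1, 1))
print(non_int_conv_args(layer))
```

With PyTorch installed, the same function could be called on the layer named in the traceback (`person_dim_reduce`). An empty result would suggest the cause lies elsewhere, for instance in the installed PyTorch version.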

piyushandilya commented 1 year ago

@joslefaure I'd really appreciate your help with this.

joslefaure commented 1 year ago

I have been looking into it, but I can't reproduce the issue. You can try uncommenting line 27 in train_net.py, but that will help only if the error is related to CUDA.

piyushandilya commented 1 year ago

@joslefaure Can you share your environment or requirements.txt? Maybe it's related to using a different version of torch (or of another dependency).

piyushandilya commented 1 year ago

I changed my PyTorch version to 1.7.1 and the error was resolved.
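For anyone applying the same fix, a quick way to confirm which PyTorch release an environment actually has is sketched below. It uses only the standard library (`importlib.metadata`, Python 3.8+). The helper names are hypothetical, and the only version taken from this thread is 1.7.1. Whether other releases also work is not established here.

```python
from importlib import metadata

EXPECTED_TORCH = "1.7.1"  # the version reported to work in this thread


def installed_version(package):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def matches_release(installed, expected):
    """Compare versions while ignoring a local build suffix such as '+cu110'."""
    if installed is None:
        return False
    return installed.split("+", 1)[0] == expected


found = installed_version("torch")
if matches_release(found, EXPECTED_TORCH):
    print(f"torch {found} matches {EXPECTED_TORCH}")
else:
    print(f"torch is {found!r}; this thread reports success with {EXPECTED_TORCH}")
```

Stripping the suffix means a CUDA-specific build such as 1.7.1+cu110 is treated as the same release as 1.7.1.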

fengjingchehu commented 6 months ago

How did you solve this problem, @piyushandilya? Looking forward to your reply.