TylerYep / torchinfo

View model summaries in PyTorch!
MIT License
2.48k stars, 117 forks

Can torchinfo support BEVFusion (https://github.com/mit-han-lab/bevfusion)? #267

Open dpan817 opened 1 year ago

dpan817 commented 1 year ago

Has anyone tried torchinfo with BEVFusion? I tried it, but it raised this error:

    TypeError: Model contains a layer with an unsupported input or output type: <mmdet3d.ops.spconv.structure.SparseConvTensor object at 0x7f3d9a48fee0>, type: <class 'mmdet3d.ops.spconv.structure.SparseConvTensor'>

TylerYep commented 1 year ago

Can you post the full code used to reproduce this error?

dpan817 commented 1 year ago

Sorry for the late reply; I was working on other issues for the past two weeks.

I ran the code without torchinfo under a debugger to capture the arguments passed to the model's forward(), then built the same arguments for the summary() call, but it still failed.

The arguments passed to the model's forward() are: [screenshot: data_parallel_module]
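Instead of stepping through a debugger, the arguments a module receives can also be recorded by temporarily wrapping its `forward` on the instance. Below is a minimal pure-Python sketch, assuming the wrapped object dispatches through `self.forward` from `__call__` (as `torch.nn.Module` does); `capture_forward_args` and `ToyDetector` are hypothetical names, not BEVFusion or torchinfo code. On PyTorch 2.0 and later, `register_forward_pre_hook(hook, with_kwargs=True)` is a built-in alternative.

```python
import functools


def capture_forward_args(module):
    """Hypothetical helper: wrap `forward` on this instance so every call
    records its positional and keyword arguments in `module.captured_calls`."""
    original_forward = module.forward
    module.captured_calls = []

    @functools.wraps(original_forward)
    def recording_forward(*args, **kwargs):
        module.captured_calls.append((args, dict(kwargs)))
        return original_forward(*args, **kwargs)

    # The instance attribute shadows the class's forward method.
    module.forward = recording_forward
    return module


class ToyDetector:
    """Hypothetical stand-in for the detector; only the call path matters."""

    def __call__(self, *args, **kwargs):
        # Like nn.Module.__call__, dispatch through self.forward.
        return self.forward(*args, **kwargs)

    def forward(self, img, points, return_loss=True, rescale=False, **kwargs):
        return {"n_points": len(points), "extra_keys": sorted(kwargs)}


det = capture_forward_args(ToyDetector())
det(img="img", points=[1, 2], return_loss=False, rescale=True, metas=[{}])
args, kwargs = det.captured_calls[0]
print(args)            # ()
print(sorted(kwargs))  # ['img', 'metas', 'points', 'rescale', 'return_loss']
```

With `MMDataParallel`, wrapping `model.module` rather than the wrapper itself would likely record the arguments after they have been scattered to the device.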

Then I built the same arguments for summary() in tools/test.py:

    if not distributed:
        model = MMDataParallel(model, device_ids=[0])
        print(f"Model:\n{model}")

        # Build a batch of one from dataset[0]: tensors get a leading batch
        # dimension, per-sample objects are wrapped in a list.
        img_tensor = dataset[0].get('img').data
        img_tensor = img_tensor.unsqueeze(0)
        points_list = [dataset[0].get('points').data]
        camera2ego_tensor = dataset[0].get('camera2ego').data
        camera2ego_tensor = camera2ego_tensor.unsqueeze(0)
        lidar2ego_tensor = dataset[0].get('lidar2ego').data
        lidar2ego_tensor = lidar2ego_tensor.unsqueeze(0)
        lidar2camera_tensor = dataset[0].get('lidar2camera').data
        lidar2camera_tensor = lidar2camera_tensor.unsqueeze(0)
        lidar2image_tensor = dataset[0].get('lidar2image').data
        lidar2image_tensor = lidar2image_tensor.unsqueeze(0)
        camera_intrinsics_tensor = dataset[0].get('camera_intrinsics').data
        camera_intrinsics_tensor = camera_intrinsics_tensor.unsqueeze(0)
        camera2lidar_tensor = dataset[0].get('camera2lidar').data
        camera2lidar_tensor = camera2lidar_tensor.unsqueeze(0)
        img_aug_matrix_tensor = dataset[0].get('img_aug_matrix').data
        img_aug_matrix_tensor = img_aug_matrix_tensor.unsqueeze(0)
        lidar_aug_matrix_tensor = dataset[0].get('lidar_aug_matrix').data
        lidar_aug_matrix_tensor = lidar_aug_matrix_tensor.unsqueeze(0)
        metas_list = [dataset[0].get('metas').data]
        # Placeholder BEV mask for a batch of one.
        gt_masks_bev_tensor = torch.zeros(1, 6, 200, 200)
        gt_bboxes_3d_list = [dataset[0].get('gt_bboxes_3d').data]
        gt_labels_3d_list = [torch.tensor(dataset[0].get('gt_labels_3d').data, device='cuda:0')]

        args_dict = {
            'return_loss': False,
            'rescale': True,
            'img': img_tensor,
            'points': points_list,
            'gt_bboxes_3d': gt_bboxes_3d_list,
            'gt_labels_3d': gt_labels_3d_list,
            'gt_masks_bev': gt_masks_bev_tensor,
            'camera_intrinsics': camera_intrinsics_tensor,
            'camera2ego': camera2ego_tensor,
            'lidar2ego': lidar2ego_tensor,
            'lidar2camera': lidar2camera_tensor,
            'camera2lidar': camera2lidar_tensor,
            'lidar2image': lidar2image_tensor,
            'img_aug_matrix': img_aug_matrix_tensor,
            'lidar_aug_matrix': lidar_aug_matrix_tensor,
            'metas': metas_list
        }
        input_dict = { }

        summary(model, input_data=[input_dict, args_dict])

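As I read torchinfo's docstring for `input_data`, when `forward()` takes several parameters you pass either a list of positional args or a dict of kwargs. Under that reading, `input_data=[input_dict, args_dict]` calls the model with two positional arguments (`{}` and the whole `args_dict`) rather than with the keyword arguments the detector's `forward()` expects. A minimal pure-Python model of that convention follows; `call_like_summary_input_data` and `toy_forward` are hypothetical names, not torchinfo code. The real `summary()` also walks `input_data` before calling the model, so switching to a dict alone probably would not avoid the error reported here.

```python
def call_like_summary_input_data(forward, input_data):
    """Hypothetical helper modelling the documented convention:
    a list/tuple is unpacked as positional args, a dict as keyword args."""
    if isinstance(input_data, (list, tuple)):
        return forward(*input_data)
    if isinstance(input_data, dict):
        return forward(**input_data)
    return forward(input_data)


def toy_forward(*args, **kwargs):
    # Stand-in for a keyword-driven detector forward(); reports what arrived.
    return {"positional": len(args), "keywords": sorted(kwargs)}


input_dict = {}
args_dict = {"return_loss": False, "rescale": True, "img": "img_tensor"}

print(call_like_summary_input_data(toy_forward, [input_dict, args_dict]))
# {'positional': 2, 'keywords': []}
print(call_like_summary_input_data(toy_forward, args_dict))
# {'positional': 0, 'keywords': ['img', 'rescale', 'return_loss']}
```
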
The error is:

    Traceback (most recent call last):
      File "tools/test.py", line 288, in <module>
        main()
      File "tools/test.py", line 250, in main
        summary(model, input_data=[input_dict, args_dict])
      File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 220, in summary
        x, correct_input_size = process_input(
      File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 246, in process_input
        correct_input_size = get_input_data_sizes(input_data)
      File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 496, in get_input_data_sizes
        return traverse_input_data(
      File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 448, in traverse_input_data
        [traverse_input_data(d, action_fn, aggregate_fn) for d in data]
      File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 448, in <listcomp>
        [traverse_input_data(d, action_fn, aggregate_fn) for d in data]
      File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 435, in traverse_input_data
        {
      File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 436, in <dictcomp>
        k: traverse_input_data(v, action_fn, aggregate_fn)
      File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 448, in traverse_input_data
        [traverse_input_data(d, action_fn, aggregate_fn) for d in data]
      File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 448, in <listcomp>
        [traverse_input_data(d, action_fn, aggregate_fn) for d in data]
      File "/home/adlink/miniconda3/envs/bevfusion_mit/lib/python3.8/site-packages/torchinfo/torchinfo.py", line 447, in traverse_input_data
        result = aggregate(
      File "/home/adlink/Downloads/Lidar_AI_Solution/CUDA-BEVFusion/bevfusion/mmdet3d/core/bbox/structures/base_box3d.py", line 46, in __init__
        assert tensor.dim() == 2 and tensor.size(-1) == box_dim, tensor.size()
    AssertionError: torch.Size([9, 1])
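Read bottom-up, the traceback shows the assertion firing inside the mmdet3d box constructor (base_box3d.py), called from torchinfo's `aggregate` while `traverse_input_data` walks the input dict. That suggests the traversal treats the iterable box object in `gt_bboxes_3d` as a container and rebuilds it from per-row results. Assuming each row is mapped to its size, a sample holding 9 boxes would produce a rebuilt tensor of shape [9, 1], which would be consistent with the reported `torch.Size([9, 1])`. The pure-Python sketch below is a simplified, hypothetical model of that failure mode, not torchinfo's code: `ToyBoxes`, `naive_traverse` and `safe_traverse` are invented names, tuples stand in for tensors and `len` for `.size()`.

```python
from collections.abc import Iterable


class ToyBoxes:
    """Hypothetical stand-in for an mmdet3d box container: iterable (one row
    per box), but its constructor insists on rows of exactly box_dim values."""

    def __init__(self, rows, box_dim=9):
        # Scalars become 1-value rows, loosely mirroring torch.as_tensor.
        self.rows = [tuple(r) if isinstance(r, (list, tuple)) else (r,) for r in rows]
        shape = (len(self.rows), len(self.rows[0]) if self.rows else 0)
        # Mirrors the "N x box_dim" shape check seen in the traceback.
        assert all(len(r) == box_dim for r in self.rows), shape

    def __iter__(self):
        return iter(self.rows)


def naive_traverse(data, fn):
    """Apply fn to leaves (tuples stand in for tensors), rebuilding every
    other iterable with its own constructor via type(data)(...)."""
    if isinstance(data, tuple):
        return fn(data)
    if isinstance(data, dict):
        return {k: naive_traverse(v, fn) for k, v in data.items()}
    if isinstance(data, Iterable) and not isinstance(data, str):
        # Fine for list; fatal for ToyBoxes, whose rows are now sizes.
        return type(data)([naive_traverse(d, fn) for d in data])
    return data


def safe_traverse(data, fn):
    """Same walk, but only plain lists and dicts are rebuilt; any other
    object (such as a box container) is passed through untouched."""
    if isinstance(data, tuple):
        return fn(data)
    if isinstance(data, dict):
        return {k: safe_traverse(v, fn) for k, v in data.items()}
    if isinstance(data, list):
        return [safe_traverse(d, fn) for d in data]
    return data


boxes = ToyBoxes([tuple(range(9))] * 9)  # 9 boxes with 9 values each
inputs = {"rescale": True, "gt_bboxes_3d": [boxes]}

try:
    naive_traverse(inputs, len)  # len stands in for tensor.size()
except AssertionError as err:
    print("naive traversal failed, shape:", err.args[0])  # (9, 1)

print(safe_traverse(inputs, len)["gt_bboxes_3d"][0] is boxes)  # True
```

The design choice in `safe_traverse`, rebuilding only built-in containers and passing unknown objects through, is one way a generic traversal can avoid calling third-party constructors; how torchinfo itself handles such objects is up to the library.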