autonomousvision / transfuser

[PAMI'23] TransFuser: Imitation with Transformer-Based Sensor Fusion for Autonomous Driving; [CVPR'21] Multi-Modal Fusion Transformer for End-to-End Autonomous Driving
MIT License

Attention Map visualization for Geometric Fusion #237

Closed by SRajasekar333 1 week ago

SRajasekar333 commented 4 weeks ago

Hello,

As discussed in https://github.com/autonomousvision/transfuser/issues/230#issuecomment-2241859024 by @ap229997, I am trying to visualize the attention map for Geometric Fusion, but I am unsure how to define the bev_points and cam_points that are passed as arguments to ._model(). Please guide me on how to resolve this.

Currently I define them as:

rgb = data['rgb'].to(args.device, dtype=torch.float32)
lidar = data['lidar'].to(args.device, dtype=torch.float32)
target_point_image = data['target_point_image'].to(args.device, dtype=torch.float32)
ego_vel = data['speed'].to(args.device, dtype=torch.float32).reshape(-1, 1)

#get geometric projections to visualize geometric fusion
bev_points = data['bev_points'][-1].numpy().astype(np.uint8)
cam_points = data['cam_points'][-1].numpy().astype(np.uint8)
bs, _, _, n_corres, p_dim = bev_points.shape
bev_points = bev_points.reshape((bs, -1, n_corres, p_dim))
cam_points = cam_points.reshape((bs, -1, n_corres, p_dim))
_, _, _, attn_map = model.module._model(rgb, lidar, ego_vel, bev_points, cam_points)  #included bev_points and cam_points as additional arguments for geometric_fusion

but I get the following error:

bev_points = data['bev_points'][-1].numpy().astype(np.uint8)
KeyError: 'bev_points'

@ap229997, @Kait0, please help me resolve this issue.

Regards.

Kait0 commented 4 weeks ago

This is just a Python error: 'bev_points' is simply not part of the data dictionary. For Python problems I recommend checking https://stackoverflow.com/. The bev_points are, for example, computed here.
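One way to confirm this kind of failure is to inspect which keys a batch actually contains before indexing it. The sketch below is a minimal illustration with a plain dictionary standing in for a batch; the helper name require_keys and the example keys and values are hypothetical, not part of the TransFuser code.

```python
def require_keys(batch, keys):
    """Return batch[k] for each k in keys, or raise a KeyError listing what is missing."""
    missing = [k for k in keys if k not in batch]
    if missing:
        raise KeyError(
            f"batch is missing {missing}; available keys: {sorted(batch)}"
        )
    return [batch[k] for k in keys]


# Stand-in batch (assumption): a real batch would hold tensors, not strings.
batch = {"rgb": "rgb-tensor", "lidar": "lidar-tensor", "speed": "speed-tensor"}

# Keys that exist are returned in the requested order.
rgb, lidar = require_keys(batch, ["rgb", "lidar"])

# Keys that the dataset never produced are reported together with what is available.
try:
    require_keys(batch, ["bev_points", "cam_points"])
except KeyError as err:
    message = str(err)
```

The error message then shows directly whether the dataset code needs to be changed to emit the missing entries.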

SRajasekar333 commented 3 weeks ago

A small comment on the viz.py mentioned in https://github.com/autonomousvision/transfuser/issues/230#issuecomment-2241859024:

As we unpack 5 variables from bev_points.shape,

bs, _, _, n_corres, p_dim = bev_points.shape

I think there could be a small correction in lines 96 and 97,

bev_points = data['bev_points'][-1].numpy().astype(np.uint8)
cam_points = data['cam_points'][-1].numpy().astype(np.uint8)

to be updated as,

bev_points = data['bev_points'].to(args.device, dtype=torch.int64)
cam_points = data['cam_points'].to(args.device, dtype=torch.int64)

After this update, I was able to unpack attn_map from:

_, _, _, attn_map = model.module._model(rgb, lidar, ego_vel, bev_points, cam_points)

Please correct me if we should instead use only the last element, data['bev_points'][-1], to get the bev_points.
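The shape argument above can be checked without loading any data. The sketch below works purely on shape tuples and assumes a batched layout of (bs, H, W, n_corres, p_dim); the helper names and the concrete sizes are made up for illustration and are not taken from the repository.

```python
def shape_after_last_index(shape):
    """Shape of x[-1] for an array-like x of the given shape: the leading axis is dropped."""
    return tuple(shape[1:])


def flatten_grid(shape):
    """Shape after reshape((bs, -1, n_corres, p_dim)) applied to a 5-D input."""
    bs, h, w, n_corres, p_dim = shape  # raises ValueError unless len(shape) == 5
    return (bs, h * w, n_corres, p_dim)


batched = (4, 8, 8, 5, 2)  # hypothetical sizes: bs=4, 8x8 grid, 5 correspondences, 2 coords

# The full batched shape unpacks into five values and flattens the grid axes.
flat = flatten_grid(batched)

# Indexing with [-1] first leaves only four axes, so the five-way unpack fails.
try:
    flatten_grid(shape_after_last_index(batched))
    unpack_ok = True
except ValueError:
    unpack_ok = False
```

Under these assumptions the five-way unpack only works on the un-indexed batch, which matches the change proposed above. If data['bev_points'] were instead a list of per-timestep arrays, [-1] would select the last timestep and could still yield a 5-D array; that is a guess about the older data format, not something confirmed in this thread.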

Regards.

ap229997 commented 2 weeks ago

Can you describe where the bev_points are coming from in the code? The code has likely changed and viz.py is not updated.

SRajasekar333 commented 2 weeks ago

I am taking curr_bev_points from https://github.com/autonomousvision/transfuser/blob/3852e65e070f2efcfaba18fbbdb48057096b8a9e/team_code_transfuser/data.py#L675 to populate data['bev_points'].

ap229997 commented 2 weeks ago

Are you calling the function from here?

SRajasekar333 commented 2 weeks ago

Actually no, I am currently calling the function from data.py. Should I call it from submission_agent.py instead?

ap229997 commented 2 weeks ago

https://github.com/autonomousvision/transfuser/blob/3852e65e070f2efcfaba18fbbdb48057096b8a9e/team_code_transfuser/data.py#L675

This returns a numpy array for bev_points, but as per your modifications data['bev_points'] is a torch tensor. Is this conversion happening in the dataloader or somewhere else?

SRajasekar333 commented 1 week ago

Yes, there is a dataloader:

dataloader = DataLoader(vdata, batch_size=args.batch_size, shuffle=False, num_workers=1, collate_fn=custom_collate_fn)

and I extract the data with:

for enum, data in enumerate(tqdm(dataloader)):

similar to what is described in viz.py.

ap229997 commented 1 week ago

Your modifications should be fine then; viz.py was written with an older version of the code.

SRajasekar333 commented 1 week ago

Ok, thanks for the clarification!