opendilab / InterFuser

[CoRL 2022] InterFuser: Safety-Enhanced Autonomous Driving Using Interpretable Sensor Fusion Transformer
Apache License 2.0

transformer tgt? #77

Open a1wj1 opened 9 months ago

a1wj1 commented 9 months ago

Hello, I don't understand the following code:

```python
if self.end2end:
    tgt = self.query_pos_embed.repeat(bs, 1, 1)
else:
    # Positional encodings over a 20x20 grid, one vector per cell
    tgt = self.position_encoding(
        torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
    )
    tgt = tgt.flatten(2)  # (bs, C, 400)
    # Append the learnable query position embeddings
    tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
    tgt = tgt.permute(2, 0, 1)  # sequence-first: (400 + n_queries, bs, C)

memory = self.encoder(features, mask=self.attn_mask)
hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
```

From the flowchart in the paper, the decoder's input should be either the waypoints or the current image, so why is the target sequence here a learnable parameter built from `torch.ones((bs, 1, 20, 20))`?

deepcs233 commented 8 months ago

Hi! The `(bs, 1, 20, 20)` tensor is used to generate the queries for the traffic map.
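
To make the shape flow concrete, here is a minimal, self-contained sketch. It assumes a DETR-style sine positional encoding; `d_model = 256` and `n_extra = 11` are illustrative values, not necessarily the repo's exact configuration:

```python
# Sketch of how the (bs, 1, 20, 20) tensor becomes decoder query positions.
import torch

def sine_position_encoding(mask_like, d_model=256):
    # mask_like: (bs, 1, H, W); returns (bs, d_model, H, W) of fixed
    # sine/cosine positional features, one vector per grid cell.
    bs, _, h, w = mask_like.shape
    y = torch.arange(h, dtype=torch.float32).view(1, h, 1).expand(bs, h, w)
    x = torch.arange(w, dtype=torch.float32).view(1, 1, w).expand(bs, h, w)
    half = d_model // 2
    dim_t = 10000 ** (2 * (torch.arange(half) // 2) / half)
    pos_x = x[..., None] / dim_t  # (bs, h, w, half)
    pos_y = y[..., None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), -1).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), -1).flatten(3)
    return torch.cat((pos_y, pos_x), 3).permute(0, 3, 1, 2)  # (bs, d_model, h, w)

bs, d_model, n_extra = 2, 256, 11                    # n_extra: extra queries (illustrative)
query_pos_embed = torch.randn(1, d_model, n_extra)   # learnable in the real model

# One positional vector per cell of the 20x20 output traffic map:
tgt = sine_position_encoding(torch.ones(bs, 1, 20, 20), d_model)  # (bs, 256, 20, 20)
tgt = tgt.flatten(2)                                              # (bs, 256, 400)
tgt = torch.cat([tgt, query_pos_embed.repeat(bs, 1, 1)], 2)       # (bs, 256, 400 + n_extra)
tgt = tgt.permute(2, 0, 1)                                        # (411, bs, 256) sequence-first
print(tgt.shape)  # torch.Size([411, 2, 256])
```

So the `torch.ones` tensor is not itself learnable; it only provides the grid shape on which the positional encoding is evaluated. Each of the resulting 400 positions corresponds to one cell of the 20x20 traffic map, and the remaining positions come from the learnable query position embeddings.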