Open kamwoh opened 1 year ago
Based on your implementation, why does the slot attention iterate with the same q & k?
https://github.com/wbw520/BotCL/blob/3dde3ac20cdecd7eea8c4b7cb0e04e2bb95f639b/model/contrast/slots.py#L37

```python
def forward(self, inputs_pe, inputs, weight=None, things=None):
    b, n, d = inputs_pe.shape
    slots = self.initial_slots.expand(b, -1, -1)
    k, v = self.to_k(inputs_pe), inputs_pe
    for _ in range(self.iters):
        q = slots  # always taking the initial slots as q?
        dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
        dots = torch.div(dots, torch.abs(dots).sum(2).expand_as(dots.permute([2, 0, 1])).permute([1, 2, 0])) * \
            torch.abs(dots).sum(2).sum(1).expand_as(dots.permute([1, 2, 0])).permute([2, 0, 1])
        attn = torch.sigmoid(dots)
        # print(torch.max(attn))
        # dsfds()
        attn2 = attn / (attn.sum(dim=-1, keepdim=True) + self.eps)
        updates = torch.einsum('bjd,bij->bid', inputs, attn2)
        if self.vis:
            slots_vis_raw = attn.clone()
            vis(slots_vis_raw, "vis", self.args.feature_size, weight, things)
    return updates, attn
```
We did not strictly follow the original Slot Attention paper. The iteration and the learned projection for q are optional in our design.
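For comparison, in the original Slot Attention the slots are re-assigned at the end of each iteration, so `q` changes from pass to pass, whereas in the code above `slots` is never updated inside the loop and `q` stays the initial slots. The sketch below (NumPy, with hypothetical shapes, and the paper's GRU/MLP update simplified to a plain re-assignment) shows the iterative form for reference only; it is not the BotCL implementation:

```python
import numpy as np

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def slot_attention(inputs, init_slots, iters=3, eps=1e-8):
    """Minimal sketch of the original Slot Attention loop.
    inputs: (b, n, d) features; init_slots: (b, num_slots, d)."""
    b, n, d = inputs.shape
    slots = init_slots.copy()
    scale = d ** -0.5
    for _ in range(iters):
        q = slots                      # q is the *current* slots
        dots = np.einsum('bid,bjd->bij', q, inputs) * scale
        attn = softmax(dots, axis=1)   # softmax over slots, per input location
        attn = attn / (attn.sum(axis=-1, keepdims=True) + eps)
        updates = np.einsum('bjd,bij->bid', inputs, attn)
        slots = updates                # re-assignment: next pass sees a new q
    return slots, attn

rng = np.random.default_rng(0)
out, attn = slot_attention(rng.normal(size=(2, 16, 8)),
                           rng.normal(size=(2, 4, 8)))
```

Without that final re-assignment, every iteration recomputes the same attention map, which is presumably why the question above asks whether the loop is intentional.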