megvii-research / TPS-CVPR2023


Question about the TPS #3

Open King4819 opened 3 months ago

King4819 commented 3 months ago

```python
import math

import megengine.functional as F
import megengine.module as M

# cal_cosine_similarity is a helper defined elsewhere in this repository


class TPS(M.Module):

    def __init__(self, variant='dTPS'):
        ...

    def forward(self, reserved, pruned, now_reserved_policy, now_pruned_policy):
        ...
        B, N, _ = reserved.shape[0], reserved.shape[1], reserved.shape[2]
        if self.variant == 'dTPS' and self.training:

            # during training, the tokens keep a fixed shape;
            # following DynamicViT, the pruned tokens' interaction with the class token
            # is removed in the multi-head attention layer
            cost_matrix = cal_cosine_similarity(
                reserved, None, mask_eye=-100)
            cost_matrix[F.broadcast_to(~now_reserved_policy.astype(
                'bool').reshape(B, 1, N), cost_matrix.shape)] = -100
            # the mask only keeps the interactions between pruned tokens
            # and their nearest reserved tokens in the current stage
            sim_th = cost_matrix.max(axis=2, keepdims=True)
            mask = (cost_matrix == sim_th).astype(
                'float32') * now_pruned_policy
            cost_matrix = (mask * cost_matrix)

            # transpose the dimensions for batched matrix multiplication
            mask = mask.transpose(0, 2, 1)
            cost_matrix = cost_matrix.transpose(0, 2, 1)
            numerator = F.exp(cost_matrix) * mask
            denominator = math.e + numerator.sum(axis=-1, keepdims=True)
            # fuse the host tokens with all matched pruned tokens
            reserved = reserved * (math.e / denominator) + \
                F.matmul(numerator / denominator, reserved)

        else:

            # during inference, or during both training and inference of eTPS,
            # the pruned tokens and reserved tokens are split from the input tokens,
            # and the pruned subset is aggregated into the matched reserved tokens,
            # dubbed host tokens
            cost_matrix = cal_cosine_similarity(
                pruned, reserved, mask_eye=None)
            sim_th = cost_matrix.max(axis=2, keepdims=True)
            mask = (cost_matrix == sim_th).astype('float32')
            cost_matrix = mask * cost_matrix
            mask = mask.transpose(0, 2, 1)
            cost_matrix = cost_matrix.transpose(0, 2, 1)
            numerator = F.exp(cost_matrix) * mask
            denominator = math.e + numerator.sum(axis=-1, keepdims=True)
            reserved = reserved * (math.e / denominator) + \
                F.matmul(numerator / denominator, pruned)

        return reserved
```

```python
if self.variant == 'dTPS' and self.training:
    cost_matrix = cal_cosine_similarity(reserved, None, mask_eye=-100)
    ...
else:
    cost_matrix = cal_cosine_similarity(pruned, reserved, mask_eye=None)
    ...
```

It seems that in the dTPS training stage, the cosine similarity is computed within the reserved token set itself, rather than between the reserved token set and the pruned token set, which is different from the inference stage. I'm wondering how this works? Thanks!

SiyuanWei commented 2 months ago

Sorry for the late reply, I didn't check my GitHub in time. For dTPS, we use Gumbel-Softmax to help train the learnable token scoring head, following DynamicViT. So in the training stage the tokens cannot actually be removed, but with the policy mask we can remove the pruned tokens' influence on the cost matrix in the dTPS modules and in the following attention operations.
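
For reference, here is a minimal NumPy sketch of what the training branch quoted above appears to do; it is an illustration, not the repo's MegEngine code. It assumes the policy tensors have shape `(B, N, 1)` with entries in `{0, 1}` (1 = reserved / pruned, as in DynamicViT), and it assumes `cal_cosine_similarity(x, None, mask_eye=-100)` computes pairwise cosine similarity within `x` with the diagonal suppressed.

```python
import numpy as np

def pairwise_cosine_similarity(x):
    # pairwise cosine similarity within x; the diagonal is set to -100 so a
    # token never matches itself (assumed behaviour of cal_cosine_similarity
    # with mask_eye=-100)
    xn = x / np.linalg.norm(x, axis=-1, keepdims=True)
    sim = xn @ xn.transpose(0, 2, 1)
    n = sim.shape[1]
    sim[:, np.arange(n), np.arange(n)] = -100.0
    return sim

def dtps_training_branch(tokens, reserved_policy, pruned_policy):
    """tokens: (B, N, C) holds ALL tokens during training;
    policies: (B, N, 1), 1 marks reserved / pruned tokens respectively."""
    B, N, _ = tokens.shape
    sim = pairwise_cosine_similarity(tokens)                     # (B, N, N)
    # columns belonging to non-reserved tokens cannot be chosen as hosts
    sim[np.broadcast_to(~reserved_policy.astype(bool).reshape(B, 1, N),
                        sim.shape)] = -100.0
    # keep, for each token, only its single most similar reserved token,
    # and keep only the rows that belong to pruned tokens
    best = sim.max(axis=2, keepdims=True)
    mask = (sim == best).astype(np.float32) * pruned_policy      # (B, N, N)
    sim = mask * sim
    # after the transpose, row i lists the pruned tokens matched to host i
    mask, sim = mask.transpose(0, 2, 1), sim.transpose(0, 2, 1)
    num = np.exp(sim) * mask
    den = np.e + num.sum(axis=-1, keepdims=True)
    # each host keeps weight e for itself and absorbs its matched pruned
    # tokens with exp(similarity) weights; unmatched tokens pass through
    return tokens * (np.e / den) + (num / den) @ tokens
```

Under these assumptions, the training branch performs the same nearest-host matching and fusion as the inference branch; the only difference is that the pruned tokens are still physically present in the tensor and are selected by the policy masks rather than by an explicit split.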