sanagno / adaptively_sparse_attention


How is the Context Pruned? #4

Open kazunator opened 2 weeks ago

kazunator commented 2 weeks ago

I've read your paper and found it amazing! I'm currently working on an idea that requires context pruning, so I thought I'd check out your source code for inspiration, but it's a bit long and some parts go over my head. Are you doing the pruning only at generation time?

sanagno commented 2 weeks ago

Hey Omar,

Pruning happens both at training time (soft) and at generation time (hard). You can take a look at the different cases here: https://github.com/sanagno/adaptively_sparse_attention/blob/main/model.py#L155
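Conceptually, the two regimes can be sketched like this. This is a toy illustration with made-up helper names, not the repo's code; it uses the closed form of 2-entmax (sparsemax) over the pair [score, 0] for the soft case, whereas the repo uses a general learnable alpha:

```python
def soft_drop_prob(score: float) -> float:
    """Training-time soft pruning (toy version).

    2-entmax (sparsemax) over the two logits [score, 0] has the closed
    form clip((score + 1) / 2, 0, 1): scores <= -1 give probability
    exactly 0, scores >= 1 give exactly 1, and everything in between is
    differentiable, so the gate can be trained end to end.
    """
    return min(max((score + 1.0) / 2.0, 0.0), 1.0)


def hard_drop(score: float) -> bool:
    """Generation-time hard pruning (toy version).

    As alpha -> inf the alpha-sigmoid approaches a step function at 0,
    which is why eval mode can simply threshold the raw score
    (cf. the `p_int_raw > 0` branch in model.py).
    """
    return score > 0.0
```

Note that both functions agree on the decision boundary: the soft probability crosses 0.5 exactly where the hard threshold flips, so hard pruning at generation is the limiting case of the trained soft gate rather than a different mechanism.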

kazunator commented 2 weeks ago

Oh gotcha. Reading through the code, this seems to be the place where the pruning is happening (or at least one of the places):

p_int_raw = (
    torch.matmul(q_int, k_int.transpose(-1, -2))
    / math.sqrt(self.int_n_embd)
    + self.int_bias
).unsqueeze(1).unsqueeze(-1)

if self.sparsity_alpha == "inf":
    # in eval mode we replace the alpha-sigmoid with the step function
    p_int = (p_int_raw > 0)[..., 0]
else:
    # Compare the raw drop scores with the value 0 to get the drop probabilities.
    p_int_raw = torch.cat([p_int_raw, torch.zeros_like(p_int_raw)], dim=-1)

    # Take only the first value of the entmax_bisect output, which is the
    # probability of dropping.
    p_int = entmax_bisect(p_int_raw.to(torch.float32), self.sparsity_alpha)[..., 0]

Is that correct? The entmax_bisect part is a bit obscure, but I'll take a look at that library to understand it.
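For anyone else reading along: entmax_bisect comes from the entmax library, and alpha-entmax interpolates between softmax (alpha = 1) and sparsemax (alpha = 2); unlike softmax it can assign exact zeros, which is what makes the soft pruning sparse. The bisection idea can be sketched in pure Python (my own toy 1-D reimplementation for intuition, not the library's code):

```python
def entmax_bisect_1d(z, alpha=1.5, n_iter=60):
    """Toy alpha-entmax over a list of scores via bisection (alpha > 1).

    Finds the threshold tau solving
        sum_i [ (alpha - 1) * z_i - tau ]_+ ** (1 / (alpha - 1)) = 1,
    then returns p_i = [ (alpha - 1) * z_i - tau ]_+ ** (1 / (alpha - 1)).
    Scores far enough below the maximum fall under tau and get
    probability exactly 0 (softmax can never do that).
    """
    a = alpha - 1.0
    zs = [a * zi for zi in z]

    def mass(tau):
        # Total probability mass at a given threshold; decreasing in tau.
        return sum(max(zi - tau, 0.0) ** (1.0 / a) for zi in zs)

    lo = max(zs) - 1.0  # mass(lo) >= 1
    hi = max(zs)        # mass(hi) == 0
    for _ in range(n_iter):
        mid = (lo + hi) / 2.0
        if mass(mid) >= 1.0:
            lo = mid
        else:
            hi = mid
    tau = (lo + hi) / 2.0
    return [max(zi - tau, 0.0) ** (1.0 / a) for zi in zs]
```

For example, entmax_bisect_1d([0.5, 0.0], alpha=2.0) gives [0.75, 0.25] (matching the sparsemax closed form), while a large gap such as [5.0, 0.0] drives the second probability to exactly zero. The library version does the same bisection batched over tensors, which is why the code above concatenates a zeros tensor: it builds the two-way [score, 0] comparison for every query-key pair at once.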