kazunator opened this issue 2 weeks ago
Hey Omar,
Pruning happens both at training time (soft) and at generation time (hard). You can take a look at the different cases here: https://github.com/sanagno/adaptively_sparse_attention/blob/main/model.py#L155
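To make the soft/hard distinction concrete, here is a minimal, framework-free sketch (my own illustration, not code from the repo — the function names and the exact reweighting scheme are assumptions): during training a keep probability in [0, 1] reweights the attention distribution differentiably, while at generation keys whose keep probability is exactly zero are dropped outright.

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def soft_pruned_attention(scores, p_keep):
    # Soft pruning (training-style): weight each key's attention by its keep
    # probability and renormalise. Equivalent to adding log(p_keep) to the
    # logits, and differentiable w.r.t. p_keep.
    w = [a * p for a, p in zip(softmax(scores), p_keep)]
    z = sum(w)
    return [x / z for x in w]

def hard_pruned_attention(scores, p_keep):
    # Hard pruning (generation-style): drop keys whose keep probability is
    # exactly zero and attend only over the survivors.
    kept = [i for i, p in enumerate(p_keep) if p > 0]
    w = softmax([scores[i] for i in kept])
    out = [0.0] * len(scores)
    for i, wi in zip(kept, w):
        out[i] = wi
    return out
```

When `p_keep` is binary (as in eval mode, where the step function replaces the alpha-sigmoid) the two variants agree; for fractional probabilities only the soft version is meaningful.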
Oh gotcha. Reading through the code, this seems to be the place where the pruning is happening (or at least one of the places):
```python
p_int_raw = (
    (
        torch.matmul(q_int, k_int.transpose(-1, -2))
        / math.sqrt(self.int_n_embd)
        + self.int_bias
    )
    .unsqueeze(1)
    .unsqueeze(-1)
)
if self.sparsity_alpha == "inf":
    # In eval mode we replace the alpha-sigmoid with the step function.
    p_int = (p_int_raw > 0)[..., 0]
else:
    # Compare the raw drop scores against 0 to get the drop probabilities.
    p_int_raw = torch.cat([p_int_raw, torch.zeros_like(p_int_raw)], dim=-1)
    # Take only the first entry of the entmax_bisect output, which is the
    # probability of dropping.
    p_int = entmax_bisect(p_int_raw.to(torch.float32), self.sparsity_alpha)[
        ..., 0
    ]
```
Is that correct? The entmax_bisect part is a bit obscure, but I'll take a look at that library to understand it.
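For intuition about the concatenated zero in that snippet (my own reading, not something stated in the repo): it gives entmax a two-way choice between "drop" (the raw score) and "keep" (the constant 0), and alpha-entmax, unlike softmax, can assign exact zeros. The general-alpha case needs the bisection that `entmax_bisect` implements, but for the special case alpha = 2 (sparsemax) the two-entry distribution has a closed form, which makes the behaviour easy to see:

```python
def two_way_sparsemax_drop_prob(score):
    # Illustrative only: alpha-entmax over the pair [score, 0] for alpha = 2
    # (sparsemax). For two entries, sparsemax reduces to the closed form
    # p_drop = clip((score + 1) / 2, 0, 1): scores <= -1 give an exact-zero
    # drop probability (the token is surely kept), while scores >= 1 drop
    # the token with certainty.
    return min(max((score + 1.0) / 2.0, 0.0), 1.0)
```

As alpha grows the transition sharpens, approaching in the limit the step function `p_int_raw > 0` used in the eval branch.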
I've read your paper and found it amazing! I'm currently working on an idea that requires context pruning, so I thought I'd check out your source code for inspiration, but it's quite long and some parts go over my head. Are you doing the pruning only at generation time?