saravanabalagi opened this issue 2 months ago
`GLU` is currently not supported, so it's treated as an element-wise operation. However, since `split` is supported, you can create your own `GLU` operation like this:
```python
import torch
import torch.nn as nn

class CustomGLU(nn.Module):
    def __init__(self, dim=1):
        super(CustomGLU, self).__init__()
        self.dim = dim

    def forward(self, x):
        first_half, second_half = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
        return first_half * torch.sigmoid(second_half)
```
If you don't use it extensively, the performance degradation shouldn't be significant.
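To sanity-check the replacement, you can compare it against `nn.GLU` directly (a quick self-contained check, assuming the same split dimension):

```python
import torch
import torch.nn as nn

class CustomGLU(nn.Module):
    """Split-based GLU: first half of the channels gated by sigmoid of the second half."""
    def __init__(self, dim=1):
        super(CustomGLU, self).__init__()
        self.dim = dim

    def forward(self, x):
        first_half, second_half = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
        return first_half * torch.sigmoid(second_half)

x = torch.randn(2, 8, 4, 4)
out_custom = CustomGLU(dim=1)(x)
out_builtin = nn.GLU(dim=1)(x)
print(torch.allclose(out_custom, out_builtin))  # True: numerically identical
```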
Hi @janthmueller, thanks for the workaround, I tried that but the network comes back with only the last conv layer pruned. No dep group with first conv layer is being returned for pruning.
After running the `get_pruning_group` method within the `prune_local` function of the `MetaPruner` class, you might notice that the group containing the first layer appears to have double the number of indices. This likely happens to prevent shape-mismatch errors. With a pruning ratio of 0.5, however, this means all of the first layer's output channels would be pruned, and since a group is ignored whenever all of its filters or channels would be pruned, nothing gets pruned in your case.
To accommodate this scenario, it's crucial to apply a targeted adjustment before gathering the `pruning_idxs`. Specifically, for groups involving the custom GLU operation, a workaround is to halve the number of pruned indices (`n_pruned`) for the affected group. This ensures that the pruning process correctly reflects the intended proportion.
To implement this adjustment, insert the following code snippet before collecting `pruning_idxs` within both the `prune_local` and `prune_global` methods:
```python
for dep, _ in group:
    if isinstance(dep.target.module, ops._SplitOp):
        n_pruned = n_pruned // 2
        break
```
By incorporating this adjustment, the pruning mechanism can appropriately handle scenarios involving the custom GLU operation, ensuring accurate pruning outcomes.
I think it might be best to fix this for all possible scenarios involving a split, maybe with a `_is_split_group` check similar to `_is_attn_group` @VainF.
Great, thanks for the workaround and the explanation!
It would be great to have this merged so that the library works with GLU directly!
Pruning a model with GLU results in an error when computing importance. GLU has no parameters but halves the input (along the given dimension), and this is not accounted for during tracing, index assignment, and importance computation.
Here's a minimal example with a simple model
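A minimal sketch of such a model, assuming a small conv stack around `nn.GLU` (the layer sizes are illustrative):

```python
import torch
import torch.nn as nn

# Hypothetical minimal model: the GLU halves the 8 channels from the
# first conv to 4, which the second conv then consumes.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.GLU(dim=1),                  # splits 8 channels into 4 + 4
    nn.Conv2d(4, 8, kernel_size=3),
)

x = torch.randn(1, 3, 16, 16)
out = model(x)
print(out.shape)  # torch.Size([1, 8, 12, 12])
```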
I then prune using `GroupNormPruner`.
This gives an index-out-of-bounds error. Note that the error is raised when the group is returned, so setting `interactive=True` in the pruner step does not help.