iurada / px-ntk-pruning

Official repository of our work "Finding Lottery Tickets in Vision Models via Data-driven Spectral Foresight Pruning" accepted at CVPR 2024
https://iurada.github.io/PX

Cannot prune the LSTM block #4

Closed EnergeticChubby closed 3 weeks ago

EnergeticChubby commented 3 months ago

Issue Description

I encountered an issue when adding an LSTM_ class to layers.py for use with the PX pruner. Specifically, the pruner fails to obtain scores for the LSTM parameters.

Implementation Details

Below is the implementation of the LSTM_ class:

from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import _VF
from torch.nn.utils.rnn import PackedSequence
# Note: _apply_permutation lives in torch.nn.modules.rnn in recent PyTorch versions
from torch.nn.modules.rnn import _apply_permutation


class LSTM_(nn.LSTM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Register weight and bias masks
        self.register_buffer('weight_ih_l0_mask', torch.ones_like(self.weight_ih_l0))
        self.register_buffer('weight_hh_l0_mask', torch.ones_like(self.weight_hh_l0))
        if self.bias:
            self.register_buffer('bias_ih_l0_mask', torch.ones_like(self.bias_ih_l0))
            self.register_buffer('bias_hh_l0_mask', torch.ones_like(self.bias_hh_l0))

    def get_expected_cell_size(self, input: torch.Tensor, batch_sizes: Optional[torch.Tensor]) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions, mini_batch, self.hidden_size)
        return expected_hidden_size

    def check_forward_args(self, input: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor],
                           batch_sizes: Optional[torch.Tensor]):
        self.check_input(input, batch_sizes)
        self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
                               'Expected hidden[0] size {}, got {}')
        self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),
                               'Expected hidden[1] size {}, got {}')

    def permute_hidden(self, hx: Tuple[torch.Tensor, torch.Tensor],
                       permutation: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(hx[1], permutation)

    def forward(self, input: torch.Tensor,
                hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
                ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        self._update_flat_weights()

        orig_input = input
        batch_sizes = None
        num_directions = 2 if self.bidirectional else 1
        real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size

        if isinstance(orig_input, PackedSequence):
            input, batch_sizes, sorted_indices, unsorted_indices = orig_input
            max_batch_size = batch_sizes[0]
            if hx is None:
                h_zeros = torch.zeros(self.num_layers * num_directions,
                                      max_batch_size, real_hidden_size,
                                      dtype=input.dtype, device=input.device)
                c_zeros = torch.zeros(self.num_layers * num_directions,
                                      max_batch_size, self.hidden_size,
                                      dtype=input.dtype, device=input.device)
                hx = (h_zeros, c_zeros)
            else:
                hx = self.permute_hidden(hx, sorted_indices)
        else:
            if input.dim() not in (2, 3):
                raise ValueError(f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead")
            is_batched = input.dim() == 3
            batch_dim = 0 if self.batch_first else 1
            if not is_batched:
                input = input.unsqueeze(batch_dim)
            max_batch_size = input.size(0) if self.batch_first else input.size(1)
            sorted_indices = None
            unsorted_indices = None
            if hx is None:
                h_zeros = torch.zeros(self.num_layers * num_directions,
                                      max_batch_size, real_hidden_size,
                                      dtype=input.dtype, device=input.device)
                c_zeros = torch.zeros(self.num_layers * num_directions,
                                      max_batch_size, self.hidden_size,
                                      dtype=input.dtype, device=input.device)
                hx = (h_zeros, c_zeros)
            else:
                if is_batched:
                    if (hx[0].dim() != 3 or hx[1].dim() != 3):
                        msg = ("For batched 3-D input, hx and cx should "
                               f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
                        raise RuntimeError(msg)
                else:
                    if hx[0].dim() != 2 or hx[1].dim() != 2:
                        msg = ("For unbatched 2-D input, hx and cx should "
                               f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors")
                        raise RuntimeError(msg)
                    hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1))
                self.check_forward_args(input, hx, batch_sizes)
                hx = self.permute_hidden(hx, sorted_indices)

        # Apply masks to the layer-0 weights and biases in place (only the l0 masks are registered in __init__)
        self.weight_ih_l0.data *= self.weight_ih_l0_mask
        self.weight_hh_l0.data *= self.weight_hh_l0_mask
        if self.bias:
            self.bias_ih_l0.data *= self.bias_ih_l0_mask
            self.bias_hh_l0.data *= self.bias_hh_l0_mask

        if batch_sizes is None:
            result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,
                              self.dropout, self.training, self.bidirectional, self.batch_first)
        else:
            result = _VF.lstm(input, batch_sizes, hx, self._flat_weights, self.bias,
                              self.num_layers, self.dropout, self.training, self.bidirectional)
        output = result[0]
        hidden = result[1:]

        if isinstance(orig_input, PackedSequence):
            output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
            return output_packed, self.permute_hidden(hidden, unsorted_indices)
        else:
            if not is_batched:
                output = output.squeeze(batch_dim)
                hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1))
            return output, self.permute_hidden(hidden, unsorted_indices)
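
As a quick sanity check that the class itself behaves before any pruning, the mask buffers do get registered and a plain forward pass runs (a minimal snippet; the import below is hypothetical and depends on where layers.py lives in your checkout):

import torch
from layers import LSTM_  # hypothetical import path, adjust to your project layout

lstm = LSTM_(input_size=8, hidden_size=16, batch_first=True)

# The four mask buffers registered in __init__ should be listed here
print([name for name, _ in lstm.named_buffers()])
# ['weight_ih_l0_mask', 'weight_hh_l0_mask', 'bias_ih_l0_mask', 'bias_hh_l0_mask']

# A plain forward pass on a (batch, seq_len, input_size) tensor also works
x = torch.randn(4, 10, 8)
out, (h, c) = lstm(x)
print(out.shape)  # torch.Size([4, 10, 16])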

The failure occurs when the _global_mask function in the PX pruner tries to retrieve the scores.

PX Pruner _global_mask Function

def _global_mask(self, sparsity):
    # Threshold scores
    global_scores = torch.cat([torch.flatten(v) for v in self.scores.values()])
    k = int((1.0 - sparsity) * global_scores.numel())
    print(len(self.masked_parameters))
    for mask, param in self.masked_parameters:
        print(param.shape) # I tried to print the masked_parameters
    if not k < 1:
        threshold, _ = torch.kthvalue(global_scores, k)
        for mask, param in self.masked_parameters:
            score = self.scores[id(param)]
            zero = torch.tensor([0.]).to(mask.device)
            one = torch.tensor([1.]).to(mask.device)
            # print('mask', mask)
            # print('param', param)
            # print('score', score)
            mask.copy_(torch.where(score <= threshold, zero, one))
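
For reference, a tiny toy illustration of the thresholding that _global_mask performs (not repo code; it just mirrors the kthvalue logic above, where sparsity is the fraction of weights that is kept):

import torch

scores = torch.tensor([0.1, 0.9, 0.3, 0.7, 0.2, 0.8, 0.4, 0.6, 0.5, 1.0])
sparsity = 0.3                              # keep ~30% of the weights
k = int((1.0 - sparsity) * scores.numel())  # number of weights to prune -> 7
threshold, _ = torch.kthvalue(scores, k)    # 7th smallest score -> 0.7
mask = torch.where(scores <= threshold, torch.tensor(0.), torch.tensor(1.))
print(mask)  # tensor([0., 1., 0., 0., 0., 1., 0., 0., 0., 1.])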

Model Structure

Here is the structure of the model:

[Screenshot: model structure]

Debugging Information

I attempted to print the masked_parameters and their shapes, which indicated that the parameters were not being handled correctly. Here are some debugging outputs:

Model parameters are shown in the following screenshots:

[Screenshot: masked parameters]

[Screenshot: parameter shapes]

Request for Help

Could someone provide insight into why the PX pruner fails to obtain the scores and suggest potential fixes? Any guidance or suggestions would be greatly appreciated. Thank you!

iurada commented 3 months ago

Hi!

Thanks for your interest in our project, and for the thorough report of the issue with our codebase! My guess is that when _global_mask iterates over self.masked_parameters (i.e. in for mask, param in self.masked_parameters:), the LSTM parameters and masks are skipped, since that module is not included by default in the generator that builds self.masked_parameters.

A fix could be to open the lib/generator.py file and add your layers.LSTM_ module to the prunable function. Something like this:

def prunable(module, batchnorm, residual):
    r"""Returns boolean whether a module is prunable.
    """
    isprunable = isinstance(module, (layers.Linear, layers.Conv2d, layers.LSTM_)) # <<< Here I added your custom LSTM module
    if batchnorm:
        isprunable |= isinstance(module, (layers.BatchNorm1d, layers.BatchNorm2d))
    if residual:
        isprunable |= isinstance(module, (layers.Identity1d, layers.Identity2d))
    return isprunable
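
To double-check that the fix took effect, something along these lines should now list the LSTM weights among the mask/parameter pairs the pruner iterates over (just a sketch: the exact name and signature of the masked_parameters generator in lib/generator.py may differ from what I write here):

from lib.generator import masked_parameters  # assumed location of the generator

# model: your network containing the LSTM_ layer
# the three flags mirror the bias / batchnorm / residual options of prunable()
for mask, param in masked_parameters(model, False, False, False):
    print(param.shape, mask.shape)
# the (4*hidden_size, input_size) and (4*hidden_size, hidden_size) LSTM weights
# should now appear in this list

Also note that, depending on how the generator pairs masks with parameters, the registration order of the mask buffers in LSTM_.__init__ may need to match the order of the LSTM's parameters (weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0), which your snippet already respects.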

Hope this helps! If you need further assistance, feel free to reach out at any time.