FMInference / DejaVu


Questions on sparse MLP implementation #2

Closed · neuer93 closed 1 year ago

neuer93 commented 1 year ago

Hi,

Thanks for sharing the code online; I really enjoyed reading your paper.

While reading the paper, I got confused by the implementation of the sparse MLP in Decentralized_FM_alpha/modules/hf_opt_sparse_mlp_attention.py, lines 552 to 558:

        hidden_states = self.fc1(hidden_states)          # dense fc1 over all neurons
        if self.predictor != None:
            hidden_states = hidden_states * self._mask   # predictor mask applied after fc1
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = torch.nn.functional.linear(
            hidden_states, self.fc2.weight.data.T, bias=self.fc2.bias.data
        )

Based on my understanding of the paper, the predictor should decide which neurons to prune before fc1 is computed. Here, the mask is applied after fc1, which means the fc1 computation itself is not sparse. I am not sure if I am missing something. Is this the right file (or function) for the sparse MLP with the sparsity predictor?
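For reference, here is a minimal sketch of what I would expect a truly sparse fc1/fc2 path to look like, assuming the predictor produces a boolean mask over the intermediate neurons (the function and variable names are hypothetical, not taken from the repo):

```python
import torch

def sparse_mlp(hidden_states, fc1, fc2, activation_fn, neuron_mask):
    # neuron_mask: bool tensor of shape [intermediate_size]; True = keep the neuron.
    idx = neuron_mask.nonzero(as_tuple=True)[0]  # indices of active neurons
    w1 = fc1.weight[idx, :]                      # [n_active, hidden_size]
    b1 = fc1.bias[idx]                           # [n_active]
    w2 = fc2.weight[:, idx]                      # [hidden_size, n_active]
    # Only the active rows of fc1 and columns of fc2 enter the matmuls,
    # so the FLOPs scale with the number of active neurons.
    h = torch.nn.functional.linear(hidden_states, w1, b1)
    h = activation_fn(h)
    return torch.nn.functional.linear(h, w2, fc2.bias)
```

In contrast, the quoted code computes the full fc1 and only zeroes the masked activations afterwards, which gives the same output values but no compute savings.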

Thanks

YixinSong-e commented 1 year ago

I think this code is used for accuracy evaluation, not for latency measurement.

neuer93 commented 1 year ago

> I think this code is used for accuracy evaluation, not for latency measurement.

Thanks for the explanation! Do you know if they have code for latency evaluation, i.e., using the sparsity predictor to reduce the end-to-end latency of LLM inference?

Thanks

Jimskns commented 8 months ago

Hi @neuer93, did you find the latency code? I have the same confusion about how the sparsity predictor reduces latency. Thanks.