turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

Is this a good way to implement LoRA `modules_to_save` for `embed_tokens` and `lm_head`? #271

Open · fahadh4ilyas opened this issue 5 months ago

fahadh4ilyas commented 5 months ago

I'm trying to implement support for a LoRA adapter that carries extra saved modules (`embed_tokens` and `lm_head`) inside it. Here is what I did to the exllamav2 code.
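For context, an adapter trained with PEFT's `modules_to_save=["embed_tokens", "lm_head"]` stores full replacement weights for those two modules alongside the usual low-rank pairs, which is why they need separate handling here. A minimal sketch for checking which of those tensors an adapter actually contains (the path and file name are placeholders):

```python
# Sketch: list the full-weight tensors a PEFT adapter carries for
# embed_tokens / lm_head. Path and file name are placeholders.
from safetensors import safe_open

with safe_open("/path/to/adapter/adapter_model.safetensors", framework="pt") as f:
    for key in f.keys():
        if key.endswith("embed_tokens.weight") or key.endswith("lm_head.weight"):
            print(key, tuple(f.get_tensor(key).shape))
```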

Inside `lora.py`, starting from line 69, I added:

                # Find target
                if key.endswith('lm_head.weight'):
                    # Full lm_head replacement saved via modules_to_save: cast to fp16
                    if tensor.dtype in (torch.bfloat16, torch.float32):
                        tensor = tensor.to(torch.float16)
                    target_module = self.model.modules_dict["lm_head"]
                    tensor = safe_move_tensor(tensor, target_module.device())
                    self.lm_head = torch.nn.Linear(target_module.in_features, tensor.shape[0], bias = False, device = "meta")
                    self.lm_head.weight = torch.nn.Parameter(tensor, requires_grad = False)
                    continue
                elif key.endswith('embed_tokens.weight'):
                    # Full embed_tokens replacement: cast to fp16 and zero the padding row
                    if tensor.dtype in (torch.bfloat16, torch.float32):
                        tensor = tensor.to(torch.float16)
                    target_module = self.model.modules_dict["model.embed_tokens"]
                    tensor = safe_move_tensor(tensor, target_module.device())
                    self.embed_tokens = torch.nn.Embedding(tensor.shape[0], self.config.hidden_size, self.config.pad_token_id, device = "meta")
                    self.embed_tokens.weight = torch.nn.Parameter(tensor, requires_grad = False)
                    self.embed_tokens.weight[self.config.pad_token_id] = 0
                    continue
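For reference, this is roughly how the patched loader would be exercised; the paths are placeholders and I'm assuming the usual exllamav2 loading flow with `ExLlamaV2Lora.from_directory`:

```python
# Sketch of loading an adapter through the patched lora.py (placeholder paths,
# standard exllamav2 loading flow assumed).
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache
from exllamav2.lora import ExLlamaV2Lora

config = ExLlamaV2Config()
config.model_dir = "/path/to/exl2-model"   # placeholder
config.prepare()

model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy = True)
model.load_autosplit(cache)

lora = ExLlamaV2Lora.from_directory(model, "/path/to/adapter")  # placeholder

# With the patch above, these are only populated when the adapter contains the keys
print(getattr(lora, "embed_tokens", None))
print(getattr(lora, "lm_head", None))
```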

And inside `model.py`, at line 615, I changed the forward pass to:

            # Substitute the adapter's full replacement modules where present
            if module.key == "model.embed_tokens" and loras is not None and loras[0].embed_tokens is not None:
                x = loras[0].embed_tokens(x)
            elif module.key == "lm_head" and loras is not None and loras[0].lm_head is not None:
                x = loras[0].lm_head(x)
            else:
                x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras)
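With that change the substitution only happens if the lora object is actually passed into the forward call, so during generation it has to go through the `loras` argument; continuing from the loading sketch above (and assuming `generate_simple` accepts `loras`):

```python
# Sketch: pass the adapter into generation so the patched branches in model.py
# are taken. Continues from the loading sketch above; prompt/settings are placeholders.
from exllamav2 import ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2BaseGenerator, ExLlamaV2Sampler

tokenizer = ExLlamaV2Tokenizer(config)
generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)
settings = ExLlamaV2Sampler.Settings()

output = generator.generate_simple("Hello, my name is", settings, 64, loras = [lora])
print(output)
```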

But the generated output is not what I expected. I also tried the same approach on the first version of exllama and it works fine there (maybe because the first version of exllama did not quantize lm_head?).
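One way to check whether the quantized `lm_head` is the problem would be to compare against a plain transformers + PEFT run of the same adapter, which applies `modules_to_save` in full precision; a minimal sketch with placeholder paths:

```python
# Reference run with transformers + PEFT (placeholder paths), to compare output
# against the patched exllamav2 path above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "/path/to/base-model", torch_dtype = torch.float16, device_map = "auto")
tok = AutoTokenizer.from_pretrained("/path/to/base-model")
model = PeftModel.from_pretrained(base, "/path/to/adapter")

inputs = tok("Hello, my name is", return_tensors = "pt").to(base.device)
out = model.generate(**inputs, max_new_tokens = 64, do_sample = False)
print(tok.decode(out[0]))
```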