Outsider565 / LoRA-GA

155 stars 6 forks source link

逐层求梯度 #2

Open hhnqqq opened 4 months ago

hhnqqq commented 4 months ago

好像代码中并没有给出论文中提到的逐层求梯度的实现

Outsider565 commented 4 months ago

代码中实现的是附录里加gradient accumulation的算法,具体可以看get_record_gradient_hook和estimated_gradient这两个函数

Outsider565 commented 4 months ago

如果要实现非gradient accumulation的算法,可以稍微改改这个hook就可以实现,只是因为实际中这样会比先移到cpu上消耗更多的显存,因此在公开代码里没有选择这样的方式

hhnqqq commented 4 months ago

噢,我知道了,那个hook里清除grad实际上就实现逐层求了,刚刚没想明白哈哈哈

hhnqqq commented 4 months ago

不知道为什么,我复习出来的结果loss直接起飞了。以下是我对代码:

def get_record_gradient_hook(model):
    def record_gradient_hook(grad):
        for p in model.parameters():
            if p.requires_grad and p.grad is not None:
                if not hasattr (p, 'grad_stored'):
                    p.grad_stored = p.grad.cpu()
                else:
                    p.grad_stored += p.grad.cpu()
                p.grad = None
        return grad

    return record_gradient_hook

def lora_ga_reinit(
    model, 
    dataloader, 
    args,
    iters: int = 1
) -> Dict[str, List[torch.Tensor]]:
    r"""
    Estimate the gradient of the model on the given dataset
    """
    from common.dataset import RepeatingLoader
    from common.utils import to_device
    print_rank_0("--->Estimating gradient for lora ga.", rank=args.global_rank)
    model.train()
    hooks = []
    for param in model.parameters():
        param.requires_grad = True
        hook = param.register_hook(get_record_gradient_hook(model))
        hooks.append(hook)
    dataloader = RepeatingLoader(dataloader)
    for iter in range(iters):
        batch = to_device(next(dataloader), args.device)
        output = model(**batch)
        if args.huggingface:
            output.loss.backward()
        else:
            output[0].backward()
        get_record_gradient_hook(model)(None)
        for p in model.parameters():
            if p.grad is not None:
                p.grad = None
    for p in model.parameters():
        p.grad_stored /= iters
    for hook in hooks:
        hook.remove()
    torch.cuda.empty_cache()
    for name, module in model.named_modules():
        if isinstance(module, LinearWithLoRA):
            print_rank_0(f'--->Module {name} is reinitiating lora weight', args.global_rank)
            module.gradient_reinit()
    def gradient_reinit(self, 
                        direction:str='ArB2r', 
                        scale:str='gd', 
                        stable_gamma:int=16, 
                        scaling_factor:int=1):
        """
        Reinitialize the LoRA weights based on the gradient of the original weight matrix.

        This method implements the core functionality of LoRA-GA (Gradient-based Adaptation).
        It performs SVD on the weight gradient and uses the resulting matrices to update
        the LoRA weights (A and B).

        Args:
            direction (str): Determines how to select A and B from SVD results.
                Options: 'ArBr', 'A2rBr', 'ArB2r'. Default is 'ArB2r'.
            scale (str): Scaling method for the new LoRA weights.
                Options: 'gd', 'unit', 'stable', 'weightS'. Default is 'stable'.
            stable_gamma (float): Gamma parameter for 'stable' scaling. Default is 16.

        The method performs the following steps:
        1. Compute SVD of the weight gradient
        2. Select A and B matrices based on the 'direction' parameter
        3. Apply scaling to A and B based on the 'scale' parameter
        4. Update the LoRA weights (weight_a and weight_b)

        Note: This method assumes that the LinearWithLora layer has gradient. Please call this
        method in the first step of training(before the model.step() call, or the gradient will be cleared.)
        """

        if hasattr(self.weight, 'grad_stored'):
            # Perform SVD on the weight gradient
            # Weight stored gradient shape [out_feature, in_feature]
            U, S, V = torch.svd_lowrank(self.weight.grad_stored.float().cuda(), q=4 * self.lora_rank, niter=4)
            # U shape [out_feature, 4r] S shape [4r, 4r] V shape [in_feature, 4r]
            V = V.T

            # Determine A and B based on the direction parameter
            if direction == "ArBr":
                B = U[:, 0:2 * self.lora_rank:2]
                A = V[1:2 * self.lora_rank:2, :]
            elif direction == "A2rBr":
                # B shape [out_feature, r]
                B = U[:, :self.lora_rank]
                # A shape [r, in_feature]
                A = V[self.lora_rank:2 * self.lora_rank, :]
            elif direction == "ArB2r":
                B = U[:, self.lora_rank:2 * self.lora_rank]
                A = V[:self.lora_rank, :]
            else:
                raise ValueError(f"Unknown direction: {direction}")

            # Apply scaling to A and B based on the scale parameter
            if scale == "gd":
                A /= scaling_factor
                B /= scaling_factor
            elif scale == "stable":
                m, n = self.weight.grad_stored.shape 
                A = A * m**0.25 / stable_gamma**0.5
                B = B * m**0.25 / stable_gamma**0.5
            elif scale == "weightS":
                _, S, _ = torch.svd_lowrank(self.weight.float(), q=4 * self.lora_rank, niter=4)
                S /= scaling_factor
                avg_s = torch.sqrt(S[:self.lora_rank]).mean().to(A.device)
                A *= avg_s
                B *= avg_s
            elif scale != "unit":
                raise ValueError(f"Unknown scale: {scale}")
            del self.weight.grad_stored

            # Update the LoRA weights
            self.weight_a.data = A.contiguous().cuda()
            self.weight_b.data = B.contiguous().cuda()
hhnqqq commented 4 months ago

我用了一个比较小的数据集(前一段时间大火的弱智吧数据集)

Outsider565 commented 4 months ago

建议estimate gradient的bsz大一点,然后后面的lr小一点/scale用gd(stable scale在有的模型上好像就是不好使,和是否是全精度有关),我们这周会更新一个和peft兼容的版本,应该会改善一些数值稳定性

hhnqqq commented 4 months ago

[2024-07-17 15:44:17,361] [INFO] --->step=1, avg_loss=26.2500, avg_time=1.47s acc:0.0, mcc:0.0, remaining_time=2.0min 15.24s, remaining_steps:93 [2024-07-17 15:44:17,373] [INFO] --->Saving training config at 1th step in [2024-07-17 15:44:18,752] [INFO] --->step=2, avg_loss=26.0000, avg_time=1.38s acc:0.0, mcc:0.0, remaining_time=2.0min 5.58s, remaining_steps:92 [2024-07-17 15:44:20,108] [INFO] --->step=3, avg_loss=25.2500, avg_time=1.36s acc:0.0, mcc:0.0, remaining_time=2.0min 2.4s, remaining_steps:91 [2024-07-17 15:44:21,461] [INFO] --->step=4, avg_loss=25.8750, avg_time=1.35s acc:0.0, mcc:0.0, remaining_time=2.0min 0.15s, remaining_steps:90 [2024-07-17 15:44:22,813] [INFO] [logging.py:96:log_dist] [Rank 0] step=5, skipped=0, lr=[5.000000000000001e-12], mom=[(0.9, 0.98)] [2024-07-17 15:44:22,813] [INFO] [timer.py:260:stop] epoch=0/micro_step=5/global_step=5, RunningAvgSamplesPerSec=11.869049501371215, CurrSamplesPerSec=11.879489557431516, MemAllocated=4.93GB, MaxMemAllocated=20.15GB [2024-07-17 15:44:22,814] [INFO] --->step=5, avg_loss=25.1250, avg_time=1.35s acc:0.0, mcc:0.0, remaining_time=1.0min 58.8s, remaining_steps:89 [2024-07-17 15:44:24,170] [INFO] --->step=6, avg_loss=25.3750, avg_time=1.36s acc:0.0, mcc:0.0, remaining_time=1.0min 58.32s, remaining_steps:88

Outsider565 commented 4 months ago

如果是第一步loss就起飞的话,我建议做如下改动:

  1. 把estimate gradient的bsz调高,调到64或者128
  2. 把scale改成gd
  3. lora部分用fp32训(非常重要),主模型用bf16训 之前我们也观察到过这个现象,原因是我们初始化之后的BA值会很大,W-etaBA可能会数值不稳定,不过你要是不急的话,也可以等我们的下个版本,也许周六周日就能放出来。
hhnqqq commented 4 months ago

还有lora用fp32训的操作,我之前不知道哈哈哈 我这里主要图方便,我用的是iterabledataset,所以不方便弄一个bsz和训练时大小不一样的dataloader 我倒是不急,因为我不是弄peft的,复现出于兴趣 感谢回复

Outsider565 commented 4 months ago

lora用fp32训

这个主要是因为lora部分占的显存很小,但其实承载了非常多信息,所以这个时候用高精度数据格式,既能不多占多少显存,又能提升数值稳定性,这个想法很类似混合精度训练。感谢关注!

hhnqqq commented 4 months ago

确实,我用gd,scaling因子设为16就正常了(这个值我是乱设的),也许可以为后来者提供参考