BaguaSys / bagua

Bagua Speeds up PyTorch
https://tutorials-8ro.pages.dev/
MIT License

Why does FusedOptimizer have a huge impact on model precision? #335

Closed: ProHuper closed this issue 2 years ago.

ProHuper commented 2 years ago

I wrapped my custom optimizer with FusedOptimizer, and the model's precision was much worse than without FusedOptimizer. As far as I understand, FusedOptimizer shouldn't affect model precision. Or is something wrong with my custom optimizer?

Here is the optimizer I use:

https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py

NOBLES5E commented 2 years ago

Thanks for opening the issue. FusedOptimizer is expected to give the same results as the wrapped optimizer for any PyTorch optimizer. We'll investigate this case.

ProHuper commented 2 years ago

I reproduced this problem with a very simple example, using fixed model parameters and input, and got the same result. With the LAMB optimizer, the parameter update at each step differs from the update without FusedOptimizer. With the Adam optimizer, the updates are identical. So I think the problem is probably related to the LAMB optimizer.

import torch
from torch.nn.modules.loss import CrossEntropyLoss
from utils.LAMB_pt import LAMB
from bagua.torch_api.contrib import FusedOptimizer
import torch.nn as nn
import torch.optim

if __name__ == '__main__':
    input = torch.load('input.pt')
    label = torch.load('label.pt')
    model = torch.load('model.pt')

    # model = nn.Sequential(
    #     nn.Linear(10, 5),
    #     nn.Linear(5, 2),
    #     nn.Linear(2, 1),
    # )

    # optimizer = torch.optim.Adam(
    #     params=model.parameters(),
    #     lr=0.1,
    #     betas=(0.9, 0.999),
    #     eps=1e-06,
    #     weight_decay=0
    # )

    optimizer = LAMB(
        params=model.parameters(),
        lr=0.1,
        betas=(0.9, 0.999),
        eps=1e-06,
        weight_decay=0
    )

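    # Move the model to GPU 0. Module.to() updates parameters in place, so the
    # optimizer constructed above still references the same Parameter objects.
    model.to(0)
    # Wrap LAMB; do_flatten=True asks FusedOptimizer to flatten the parameters
    # into contiguous memory.
    optimizer = FusedOptimizer(optimizer, do_flatten=True)
    input = input.to(0)
    label = label.to(0)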

    print('original:')
    print(optimizer.param_groups[0]['params'][0])

    for i in range(10):
        print('running new step')
        optimizer.zero_grad()
        output = model(input)
        loss = (output - label).pow(2).sum()
        loss.backward()

        optimizer.step()
        print(optimizer.param_groups[0]['params'][0])
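The step-by-step comparison described above can be automated with a small harness that trains two identical copies of a model and records the largest parameter difference after each step. This is a minimal sketch: `max_param_drift` and `make_adam` are hypothetical names, and the example pits Adam against Adam on the CPU, so the drift should stay at zero. To test the fused case, the second factory could return the wrapped optimizer from the script, for example `FusedOptimizer(LAMB(params=ps, lr=0.1), do_flatten=True)` with the model on the GPU, but that variant is untested here.

```python
import copy

import torch
import torch.nn as nn


def max_param_drift(model, make_opt_a, make_opt_b, inputs, labels, steps=10):
    """Hypothetical helper: train two identical copies of `model` with two
    optimizer factories and return the largest parameter difference per step."""
    model_a, model_b = copy.deepcopy(model), copy.deepcopy(model)
    opt_a = make_opt_a(model_a.parameters())
    opt_b = make_opt_b(model_b.parameters())
    drifts = []
    for _ in range(steps):
        for m, opt in ((model_a, opt_a), (model_b, opt_b)):
            opt.zero_grad()
            # Same squared-error loss as the reproduction script.
            loss = (m(inputs) - labels).pow(2).sum()
            loss.backward()
            opt.step()
        # Largest absolute difference across all corresponding parameters.
        drift = max(
            (pa - pb).abs().max().item()
            for pa, pb in zip(model_a.parameters(), model_b.parameters())
        )
        drifts.append(drift)
    return drifts


def make_adam(ps):
    # Same hyperparameters as the commented-out Adam in the script.
    return torch.optim.Adam(ps, lr=0.1, betas=(0.9, 0.999), eps=1e-6)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2), nn.Linear(2, 1))
x, y = torch.randn(8, 10), torch.randn(8, 1)
print(max_param_drift(model, make_adam, make_adam, x, y))
```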

NOBLES5E commented 2 years ago

Thanks! An example is very useful for debugging. Could you also provide the .pt files?

ProHuper commented 2 years ago

pt_files.zip

ProHuper commented 2 years ago

It seems that the LAMB optimizer uses a weight_norm that depends on each parameter itself, so if we group the parameters into one big tensor, weight_norm changes. Any idea how to do parameter fusion in this case?

  weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
  adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
  if group['weight_decay'] != 0:
      adam_step.add_(p.data, alpha=group['weight_decay'])
  adam_norm = adam_step.pow(2).sum().sqrt()

  if weight_norm == 0 or adam_norm == 0:
      trust_ratio = 1
  else:
      trust_ratio = weight_norm / adam_norm

  state['weight_norm'] = weight_norm
  state['adam_norm'] = adam_norm
  state['trust_ratio'] = trust_ratio

  p.data.add_(adam_step, alpha=-step_size * trust_ratio)
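To see the effect described in this comment, the sketch below applies the weight_norm expression from the snippet once per tensor and once to a single flattened buffer. The tensors are random stand-ins, and modeling fusion as a plain concatenation of flattened parameters is an assumption based on the "one big tensor" description, not a description of Bagua's internals.

```python
import torch

# Hypothetical stand-ins for two parameter tensors of a model.
torch.manual_seed(0)
params = [torch.randn(5, 10), torch.randn(5)]

# LAMB as written above: one clamped L2 norm per parameter tensor.
per_param = [p.pow(2).sum().sqrt().clamp(0, 10) for p in params]

# The same expression applied to one flattened buffer holding every parameter.
flat = torch.cat([p.reshape(-1) for p in params])
fused = flat.pow(2).sum().sqrt().clamp(0, 10)

print("per-parameter norms:", [n.item() for n in per_param])
print("fused norm:", fused.item())

# Ignoring the clamp, the fused norm is sqrt(sum of squared per-tensor norms),
# so it generally differs from each per-tensor norm.
expected = torch.stack([p.norm() for p in params]).pow(2).sum().sqrt()
assert torch.allclose(flat.norm(), expected)
```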

NOBLES5E commented 2 years ago

@wangraying is working on a less intrusive way to implement fused optimizer in https://github.com/BaguaSys/bagua/pull/207. Let's see whether that works in this case.

In the worst case we can still go with the simplest solution, which would be disabling fusion for weight_norm operations. But that's a last resort we'd rather avoid at this stage 😃

wangraying commented 2 years ago

@ProHuper We fixed the fused optimizer in master; could you please check it again? Please let us know if it does not work as expected. BTW, the API changed a bit: FusedOptimizer has been replaced by fuse_optimizer (see the documentation).

Thanks a lot.

ProHuper commented 2 years ago

Thanks, I tried the new API, but it still doesn't work correctly. I also got the warning below:

[Screenshot of the warning]

I still think it's related to the implementation of the LAMB optimizer: it needs a weight_norm factor computed separately for each parameter.

import torch
from torch.optim import Optimizer


class Lamb(Optimizer):
    r"""Implements Lamb algorithm.
    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        adam (bool, optional): always use trust ratio = 1, which turns this into
            Adam. Useful for comparison purposes.
    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0, adam=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay)
        self.adam = adam
        super(Lamb, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Paper v3 does not use debiasing.
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']
                # Apply bias to lr to avoid broadcast.
                step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1

                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)

                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                if group['weight_decay'] != 0:
                    adam_step.add_(p.data, alpha=group['weight_decay'])

                adam_norm = adam_step.pow(2).sum().sqrt()
                if weight_norm == 0 or adam_norm == 0:
                    trust_ratio = 1
                else:
                    trust_ratio = weight_norm / adam_norm
                state['weight_norm'] = weight_norm
                state['adam_norm'] = adam_norm
                state['trust_ratio'] = trust_ratio
                if self.adam:
                    trust_ratio = 1

                p.data.add_(adam_step, alpha=-step_size * trust_ratio)

        return loss

wangraying commented 2 years ago

Oh, that's embarrassing. I'll look into this problem soon.

wangraying commented 2 years ago

The fused optimizer assumes that each parameter and its state tensors have the same data type and size (which is the case for all official PyTorch optimizers).

The Lamb optimizer in your case stores two state entries, weight_norm and adam_norm, that are 0-dim tensors and therefore do not satisfy this assumption.
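A minimal sketch of the assumption described above: the hypothetical helper `find_nonconforming_state` flags state entries that are tensors but do not match their parameter's shape or dtype. It also shows how the `.item()` conversion proposed in this comment turns the offending entries into Python floats. The helper is not part of Bagua's API, the toy state models only the entries discussed here, and treating non-tensor values as exempt is an assumption consistent with the proposed fix.

```python
import torch


def find_nonconforming_state(param, state):
    """Hypothetical helper: return the state keys holding tensors whose shape
    or dtype differs from `param`. Non-tensor values (Python ints/floats) are
    skipped, on the assumption that only tensor state is fused."""
    return [
        key
        for key, value in state.items()
        if torch.is_tensor(value)
        and (value.shape != param.shape or value.dtype != param.dtype)
    ]


p = torch.randn(4, 3)
# Toy state modeled on the Lamb code above.
state = {
    "step": 1,
    "exp_avg": torch.zeros_like(p),
    "exp_avg_sq": torch.zeros_like(p),
    "weight_norm": p.pow(2).sum().sqrt().clamp(0, 10),  # 0-dim tensor
    "adam_norm": torch.tensor(1.0),  # 0-dim tensor
}
print(find_nonconforming_state(p, state))  # ['weight_norm', 'adam_norm']

# The proposed fix: store the norms as Python floats via .item().
state["weight_norm"] = state["weight_norm"].item()
state["adam_norm"] = state["adam_norm"].item()
print(find_nonconforming_state(p, state))  # []
```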

However, we can easily make it compliant by changing the following two lines in the code you provided:

state['weight_norm'] = weight_norm
state['adam_norm'] = adam_norm

to

state['weight_norm'] = weight_norm.item()
state['adam_norm'] = adam_norm.item()

Note that with this change, weight_norm and adam_norm will be computed on the "fused tensors", which is not exactly the same as computing them for the original individual tensors.
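This note can be made concrete with a small sketch comparing one LAMB update computed per tensor with the same update computed on a flattened buffer. `trust_ratio` is a hypothetical helper reproducing the formula from `Lamb.step()` above; the weights and Adam directions are random stand-ins rather than real optimizer state. When parameter tensors have different scales (say, a weight matrix and a small bias), a single shared trust ratio rescales every tensor by the same factor, so the two updates diverge.

```python
import torch


def trust_ratio(weight, adam_step):
    # Same formula as Lamb.step() above: clamped weight norm over update norm,
    # returned as a Python float via .item().
    weight_norm = weight.pow(2).sum().sqrt().clamp(0, 10)
    adam_norm = adam_step.pow(2).sum().sqrt()
    if weight_norm == 0 or adam_norm == 0:
        return 1.0
    return (weight_norm / adam_norm).item()


torch.manual_seed(0)
lr = 0.1
# Hypothetical weights and Adam directions for two tensors of different scale.
weights = [torch.randn(5, 10), 0.1 * torch.randn(5)]
adam_steps = [torch.randn(5, 10), torch.randn(5)]

# Unfused LAMB: one trust ratio per parameter tensor.
unfused = torch.cat([
    (-lr * trust_ratio(w, s) * s).reshape(-1)
    for w, s in zip(weights, adam_steps)
])

# Fused LAMB: a single trust ratio over the flattened buffers.
flat_w = torch.cat([w.reshape(-1) for w in weights])
flat_s = torch.cat([s.reshape(-1) for s in adam_steps])
fused = -lr * trust_ratio(flat_w, flat_s) * flat_s

print("max |unfused - fused|:", (unfused - fused).abs().max().item())
```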

Let us know if it works! Thanks

wangraying commented 2 years ago

I'll close this if no more problems are raised.