DeepLink-org / DIOPI

Implement some GB ops #1342

Closed: DoorKickers closed this pull request 1 month ago

DoorKickers commented 3 months ago

Motivation and Context

Description

1. Support for the Peking University national-standard (GB) operators

Completed the 24 pending/under-development national-standard operators from #1306, covering sections 7.2.10.3 ~ 7.2.15.4 and 7.2.4.3 ~ 7.2.11.1.

2. Fixed the intermittent accuracy failure of the adam operator test in CI by aligning the adam implementation with torch; see https://github.com/pytorch/pytorch/blob/ba10259115b1c89292f3499cdba059ebb9ec6b4e/torch/optim/adam.py#L320

The changes made in DIOPI are as follows:

diopiError_t diopiAdam(diopiContextHandle_t ctx, diopiTensorHandle_t param, diopiConstTensorHandle_t grad, diopiTensorHandle_t exp_avg,
                       diopiTensorHandle_t exp_avg_sq, diopiTensorHandle_t max_exp_avg_sq, float lr, float beta1, float beta2, float eps, float weight_decay,
                       int64_t step, bool amsgrad) {
    impl::aten::setCurStream(ctx);
    auto atParam = impl::aten::buildATen(param);
    auto atGrad = impl::aten::buildATen(grad);
    auto atExpAvg = impl::aten::buildATen(exp_avg);
    auto atExpAvgSq = impl::aten::buildATen(exp_avg_sq);
    auto atMaxExpAvgSq = impl::aten::buildATen(max_exp_avg_sq);

    // Fold L2 weight decay into the gradient, as in torch's single-tensor adam.
    auto grad_d = atGrad.data();
    if (weight_decay != 0) {
        grad_d = grad_d.add(atParam, weight_decay);
    }
    // exp_avg = beta1 * exp_avg + (1 - beta1) * grad, written with lerp_ to match torch.
    // Previous form: atExpAvg.mul_(beta1).add_(grad_d, 1 - beta1);
    atExpAvg.lerp_(grad_d, 1 - beta1);
    // exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * conj(grad)
    atExpAvgSq.mul_(beta2).addcmul_(grad_d, grad_d.conj(), 1 - beta2);

    at::Tensor denom;
    auto bias_correction1 = 1 - pow(beta1, step);
    auto bias_correction2 = 1 - pow(beta2, step);
    if (amsgrad) {
        // Keep the running maximum of exp_avg_sq and build the denominator from it.
        CALL_ATEN_CUDA_FUNC(maximum_out, atMaxExpAvgSq, atMaxExpAvgSq, atExpAvgSq);
        denom = atMaxExpAvgSq.sqrt().div_(sqrt(bias_correction2)).add_(eps);
    } else {
        denom = atExpAvgSq.sqrt().div_(sqrt(bias_correction2)).add_(eps);
    }
    // param -= (lr / bias_correction1) * exp_avg / denom
    auto stepSize = lr / bias_correction1;
    atParam.addcdiv_(atExpAvg, denom, -1 * stepSize);

    return diopiSuccess;
}
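
For reference, this is the update the snippet above implements (informal notation; it follows the referenced torch single-tensor adam, with t = step, m = exp_avg, v = exp_avg_sq):

    g_t     = grad_t + weight_decay * param_{t-1}
    m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t                 (via lerp_)
    v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2               (via addcmul_)
    vhat_t  = max(vhat_{t-1}, v_t) if amsgrad else v_t
    param_t = param_{t-1} - (lr / (1 - beta1^t)) * m_t / (sqrt(vhat_t) / sqrt(1 - beta2^t) + eps)

The lerp_ form m_{t-1} + (1 - beta1) * (g_t - m_{t-1}) is algebraically the same first-moment update as before, but keeps the operation order identical to torch, which is the alignment this fix is after.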

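Below is a minimal standalone libtorch sketch (not part of this PR; shapes and hyperparameters are illustrative) that replays one step of the same sequence of ops next to torch::optim::Adam, as one way to sanity-check the alignment. The printed difference should be at floating-point rounding level.

#include <torch/torch.h>
#include <cmath>
#include <iostream>

int main() {
    double lr = 1e-3, beta1 = 0.9, beta2 = 0.999, eps = 1e-8, wd = 1e-2;

    auto paramRef = torch::randn({4, 4}, torch::requires_grad());
    auto paramManual = paramRef.detach().clone();
    auto grad = torch::randn({4, 4});

    // Reference: one step of torch::optim::Adam (amsgrad off by default).
    torch::optim::Adam opt({paramRef},
        torch::optim::AdamOptions(lr).betas(std::make_tuple(beta1, beta2)).eps(eps).weight_decay(wd));
    paramRef.mutable_grad() = grad.clone();
    opt.step();

    // Manual: the same op sequence as in diopiAdam, with step = 1 and fresh state.
    auto expAvg = torch::zeros_like(paramManual);
    auto expAvgSq = torch::zeros_like(paramManual);
    auto g = grad.add(paramManual, wd);
    expAvg.lerp_(g, 1 - beta1);
    expAvgSq.mul_(beta2).addcmul_(g, g.conj(), 1 - beta2);
    double bc1 = 1 - std::pow(beta1, 1);
    double bc2 = 1 - std::pow(beta2, 1);
    auto denom = expAvgSq.sqrt().div_(std::sqrt(bc2)).add_(eps);
    paramManual.addcdiv_(expAvg, denom, -(lr / bc1));

    std::cout << "max abs diff: "
              << (paramRef.detach() - paramManual).abs().max().item<double>() << std::endl;
    return 0;
}
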
Use cases (Optional)

BC-breaking (Optional)

Checklist

Before PR:

After PR: