PaddlePaddle / PALM

A Fast, Flexible, Extensible and Easy-to-use NLP Large-scale Pretraining and Multi-task Learning Framework.

Questions about gradient updates in PALM #84

Open · gobigrassland opened this issue 4 years ago

gobigrassland commented 4 years ago

I have three questions. (1) On multi-task gradient updates, my reading of the code (multi_task/run.py) is: task1 produces loss1 and the model parameters are updated once; task2 then produces loss2 and a second update is applied on top of the previous one; these two steps repeat in a loop. Is this understanding correct?
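
A minimal, self-contained sketch of this alternating schedule (plain numpy, not PALM code; the two quadratic "tasks" and all names here are made up purely for illustration):

    import numpy as np

    # Toy model of the loop described in (1): two tasks share one parameter
    # vector; each step samples one task, computes that task's gradient, and
    # updates the shared parameters in place, so each task's update starts
    # from the parameters the previous task just produced.
    w = np.zeros(3)                          # shared model parameters
    lr = 0.1
    targets = {"task1": np.ones(3), "task2": -np.ones(3)}

    for step in range(10):
        task = "task1" if step % 2 == 0 else "task2"   # alternate tasks
        grad = 2.0 * (w - targets[task])     # gradient of ||w - target||^2
        w -= lr * grad                       # one parameter update per step
        print(step, task, w)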

(2) Below I paste part of the build_backward function from train.py. My understanding is that param_list stores the model parameter values, so the expression updated_param = param - param_list[param.name] * weight_decay * optimizer.get_cur_learning_rate() updates a parameter by subtracting the parameter value itself multiplied by a coefficient. No gradient is used here at all. Shouldn't a parameter update be w = w - alpha * grad_w? (See the sketch after the pasted code.)

def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=None):
    """
    Build backward computation graph and training strategy.

    Arguments:
        - optimizer: the optimizer used to compute and apply parameter updates.
        - weight_decay: optional, default is None (disable weight decay).
        - use_ema: optional, default is False. The flag to control whether to apply Exponential Moving Average strategy on parameter updates.
        - ema_decay: optional, default is None. Only works with use_ema == True. Controls the decay rate of the EMA strategy.

    """
    # build optimizer
    assert self._loss_var is not None and self._train_init_prog is not None, "train graph not found! You should build_forward first."
    optimizer._set_prog(self._train_prog, self._train_init_prog)
    with fluid.program_guard(self._train_prog, self._train_init_prog):
        # optimizer._build() appends the standard gradient-update ops
        # (w = w - lr * grad_w) and returns the (param, grad) pairs.
        param_grads = optimizer._build()

        if weight_decay is not None:
            # param_list (a stop-gradient snapshot of each parameter) and
            # exclude_from_weight_decay (skips layer-norm and bias params)
            # are defined here in the full source; omitted from this paste.

            for param, grad in param_grads:
                if exclude_from_weight_decay(param.name):
                    continue
                with param.block.program._optimized_guard(
                    [param, grad]), fluid.framework.name_scope("weight_decay"):
                    # Decoupled weight decay, applied on top of the
                    # gradient step the optimizer already performed.
                    updated_param = param - param_list[
                        param.name] * weight_decay * optimizer.get_cur_learning_rate()
                    fluid.layers.assign(output=param, input=updated_param)

        if use_ema:
            ema = fluid.optimizer.ExponentialMovingAverage(ema_decay)
            ema.update()

    self._exe.run(self._train_init_prog)
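
To make the weight-decay question in (2) concrete, here is a toy numpy sketch (not PALM code) of decoupled weight decay, under the assumption that optimizer._build() has already appended the w = w - alpha * grad_w ops, so the pasted loop only adds the extra decay term on top:

    import numpy as np

    # Minimal sketch: the optimizer performs the gradient step, then the
    # weight-decay op subtracts param_old * weight_decay * lr, mirroring
    # updated_param = param - param_list[param.name] * weight_decay * lr.
    w = np.array([1.0, -2.0, 0.5])           # a parameter before the step
    grad = np.array([0.3, 0.1, -0.2])        # its gradient
    lr, weight_decay = 0.01, 0.1

    w_old = w.copy()                         # analogous to param_list[param.name]
    w = w - lr * grad                        # gradient step (inside the optimizer)
    w = w - w_old * weight_decay * lr        # separate decoupled weight-decay op
    print(w)                                 # combined effect of both updates

This is the AdamW-style decoupled formulation: the decay term uses the parameter value rather than the gradient, which is why no gradient appears in the pasted expression.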

(3) PALM is a multi-task framework for NLP. Has a multi-task framework for images been released, or is one planned?

xyzhou-puck commented 4 years ago

Hi, I was formerly a developer on PALM. Thanks for your interest, and I hope PALM keeps getting better and better.