Closed by Windaway 2 years ago.
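```python
import torch


def grad_scale(x, scale):
    # Forward: returns x unchanged. Backward: the gradient w.r.t. x is multiplied by `scale`.
    y = x
    y_grad = x * scale
    return y.detach() - y_grad.detach() + y_grad


def round_pass(x):
    # Forward: rounds x. Backward: straight-through estimator (gradient passes as identity).
    y = x.round()
    y_grad = x
    return y.detach() - y_grad.detach() + y_grad
```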
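A quick check of what the two helpers do, assuming PyTorch (the issue does not show its imports). In the forward pass `round_pass` returns the rounded tensor and `grad_scale` returns its input unchanged. In the backward pass `round_pass` passes the gradient straight through, and `grad_scale` multiplies it by `scale`. Neither helper bounds the values in any way. The tensors below are made-up inputs for illustration.

```python
import torch


def grad_scale(x, scale):
    # Copied from the issue: identity in forward, gradient scaled by `scale` in backward.
    y = x
    y_grad = x * scale
    return y.detach() - y_grad.detach() + y_grad


def round_pass(x):
    # Copied from the issue: rounding in forward, identity gradient in backward.
    y = x.round()
    y_grad = x
    return y.detach() - y_grad.detach() + y_grad


x = torch.tensor([0.4, 1.6, -2.7], requires_grad=True)
r = round_pass(x)
r.sum().backward()
print(r.detach())  # rounded values: 0, 2, -3
print(x.grad)      # all ones: rounding is bypassed in the backward pass

s = torch.tensor(2.0, requires_grad=True)
g = grad_scale(s, 0.1)
(g * 3.0).backward()
print(g.detach())  # forward value is still (approximately) 2.0
print(s.grad)      # 3.0 * 0.1, i.e. the upstream gradient times `scale`
```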
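These don't seem to restrict the values to a fixed range. Should the result be clamped after `grad_scale`, to something like the int8 range (-128 to 127)?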
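For reference, these helpers match the `gradscale`/`roundpass` pattern from LSQ (Learned Step Size Quantization). There, `grad_scale` is applied to the step size `s`, not to the values. The range is enforced by clipping `v / s` to `[Qn, Qp]` before `round_pass`. The sketch below assumes the helpers are used that way. `lsq_quantize`, its parameters, and the `1 / sqrt(numel * Qp)` gradient-scale factor are illustrative assumptions, not code from this repository.

```python
import torch


def grad_scale(x, scale):
    y = x
    y_grad = x * scale
    return y.detach() - y_grad.detach() + y_grad


def round_pass(x):
    y = x.round()
    y_grad = x
    return y.detach() - y_grad.detach() + y_grad


def lsq_quantize(v, s, num_bits=8, signed=True):
    # Hypothetical helper: LSQ-style fake quantization built on the two functions above.
    if signed:
        qn, qp = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1  # [-128, 127] for 8 bits
    else:
        qn, qp = 0, 2 ** num_bits - 1
    g = 1.0 / (v.numel() * qp) ** 0.5  # assumed gradient-scale factor for the step size
    s = grad_scale(s, g)               # only the step size's gradient is rescaled
    v = v / s
    v = torch.clamp(v, qn, qp)         # the range is enforced here, before rounding
    v_bar = round_pass(v)
    return v_bar * s                   # map the integer codes back to the input scale


torch.manual_seed(0)
w = torch.randn(1000) * 5
step = torch.tensor(0.01, requires_grad=True)  # learnable step size (made-up value)
w_hat = lsq_quantize(w, step)
w_hat.sum().backward()

codes = (w_hat / step).detach().round()
print(codes.min().item(), codes.max().item())  # integer codes stay within [-128, 127]
print(step.grad)                               # the step size receives a gradient
```

Clamping the output of `grad_scale` itself would have no effect on the value range, because in the forward pass it just returns `s`.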