Closed: ClashLuke closed this issue 1 year ago
I can't figure out the custom gradient. The current best attempt is:
import typing

import jax
from jax import lax, numpy as jnp

from src.backend import get_param, promote_to, with_context
from src.constants import ParallelAxes
from src.context import Context
from src.model.activate import activate_forward, activate_grad


def prenorm(fn: typing.Callable[[Context, jnp.ndarray], jnp.ndarray]):
    def _fn(ctx: Context, inp: jnp.ndarray, *args) -> jnp.ndarray:
        ctx = ctx.add_to_prefix("prenorm")
        inp = scale_norm_act(ctx, inp, ctx.dims.features, act=False, init_mean=None)
        out = fn(ctx, inp, *args)
        return scale_norm_act(ctx, out, ctx.dims.features, act=False)

    return _fn


def norm_forward(ctx: Context, src: jnp.ndarray, wgt: typing.Optional[jnp.ndarray] = None, psum: bool = False,
                 act: bool = True):
    run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32)
    original_dtype = src.dtype
    src_fp64 = promote_to(src, run_type)
    if psum:
        src_fp64 = lax.psum(src_fp64, axis_name=ParallelAxes.model)
    mean = src_fp64.mean(-1, keepdims=True)
    mean_x = src_fp64 - mean
    std = lax.abs(mean_x).sum(-1, keepdims=True)
    norm_out = mean_x / std
    out = norm_out * wgt.reshape((1,) * (src.ndim - 1) + (-1,))
    if act:
        out = activate_forward(out)
    out = out.astype(original_dtype)
    return out, norm_out, std


@with_context()
def scale_norm_act(ctx: Context, inp: jnp.ndarray, feature_dim: int, weight: typing.Optional[jnp.ndarray] = None,
                   psum: bool = False, act: bool = True, init_mean: typing.Optional[float] = 1) -> jnp.ndarray:
    run_type = jnp.promote_types(ctx.model.computation_dtype, jnp.float32)
    if weight is None:
        if init_mean is None:
            # init to 0 when loading from a checkpoint, so new layers are learned slowly (-> rezero, but for the input)
            # init to 1 otherwise, to make sure the model can learn
            init_mean = float(not bool(ctx.training.checkpoint_load_path))
        weight = get_param(ctx, "scale", [feature_dim], std=0, mean=init_mean, dtype=run_type,
                           lr_scale=ctx.optimizer.norm_scale)
    if ctx.is_initializing:
        return inp

    @jax.custom_gradient
    def _fn(src: jnp.ndarray, wgt: jnp.ndarray):
        original_dtype = src.dtype
        out, norm_out, std = norm_forward(ctx, src, wgt, psum, act)

        def _grad(dy: jnp.ndarray) -> typing.Tuple[jnp.ndarray, jnp.ndarray]:
            norm_out_fp64 = promote_to(norm_out, run_type)
            reshaped_weight = wgt.reshape((1,) * (src.ndim - 1) + (-1,))
            dy = promote_to(dy, run_type)
            if act:
                dy = dy * activate_grad(norm_out_fp64 * reshaped_weight)
            x = src
            d_wgt = (dy * norm_out_fp64).sum(list(range(src.ndim - 1))).reshape((-1,))
            x_mean = x - x.mean(-1, keepdims=True)
            x_mean_abs = lax.abs(x_mean)
            x_mean_abs_prod = (x_mean_abs / feature_dim).prod(-1, keepdims=True)
            x_mean_abs_size = x_mean_abs * feature_dim
            x_div = x_mean / x_mean_abs
            fac = (feature_dim - 1) * x_mean_abs_size.sum()
            fac += x_mean * ((x_mean_abs_prod * x_div).sum() + (x.size - 2) * x_div / feature_dim ** feature_dim)
            fac /= x_mean_abs.sum() ** 2
            dy = fac * reshaped_weight * dy
            if psum:
                dy = lax.psum(dy, axis_name=ParallelAxes.model)
            return dy.astype(original_dtype), d_wgt

        return out, _grad

    return _fn(inp, weight)
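A minimal sketch for cross-checking a backward like this (not part of the repo; reference_grads is a made-up helper name, and it assumes activate_forward is ordinary differentiable JAX code): since norm_forward is plain JAX, jax.vjp can produce a reference gradient to compare the hand-written _grad against on a single device (psum=False):

def reference_grads(ctx: Context, inp: jnp.ndarray, weight: jnp.ndarray, act: bool = True):
    # Let JAX differentiate the plain forward pass; this is the ground truth
    # that the custom _grad should reproduce (single device, so psum=False and
    # lax.psum never enters the graph).
    def forward(src, wgt):
        out, _, _ = norm_forward(ctx, src, wgt, psum=False, act=act)
        return out

    out, vjp_fn = jax.vjp(forward, inp, weight)
    d_inp, d_wgt = vjp_fn(jnp.ones_like(out))
    return d_inp, d_wgt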
The fac term in _grad is based on WolframAlpha's calculated partial derivative (screenshot of the WolframAlpha result not included).
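For reference, here is a hand re-derivation of that partial derivative (my own algebra, so it may differ in form from the WolframAlpha output). With $n$ features, $c_i = x_i - \frac{1}{n}\sum_k x_k$, $s = \sum_k |c_k|$, and $y_i = c_i / s$, the VJP for an incoming cotangent $g_j = \partial L / \partial y_j$ (i.e. dy after multiplying by the activation gradient and the weight, as _grad does) is

$$\frac{\partial s}{\partial x_j} = \operatorname{sign}(c_j) - \frac{1}{n}\sum_k \operatorname{sign}(c_k), \qquad
\frac{\partial L}{\partial x_j} = \frac{1}{s}\Big(g_j - \frac{1}{n}\sum_k g_k - \frac{\partial s}{\partial x_j}\sum_k g_k\, y_k\Big).$$

The weight gradient is unchanged from the code: $\partial L / \partial w_j = \sum_{\text{batch}} g_j\, y_j$ with $g$ taken before the weight factor.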
I am dropping the issue for now, as the expected gain is too low to spend more days on it.
Others have reported increased stability with L1-BatchNorm, so it might be worth a try for us as well.
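For anyone picking this up later, here is a self-contained toy version of the L1 normalization with a custom gradient that implements the formula above and passes JAX's finite-difference check (a sketch stripped of the Context/psum/dtype machinery; l1_norm is a made-up name, not something in the repo):

import jax
import jax.numpy as jnp
from jax import lax
from jax.test_util import check_grads

jax.config.update("jax_enable_x64", True)  # makes the finite-difference check well behaved


@jax.custom_gradient
def l1_norm(x: jnp.ndarray):
    c = x - x.mean(-1, keepdims=True)      # centered input
    s = lax.abs(c).sum(-1, keepdims=True)  # L1 "std": sum of absolute deviations
    y = c / s

    def grad(g):
        ds_dx = lax.sign(c) - lax.sign(c).mean(-1, keepdims=True)  # d s / d x_j
        g_dot_y = (g * y).sum(-1, keepdims=True)                   # sum_k g_k y_k
        return (g - g.mean(-1, keepdims=True) - ds_dx * g_dot_y) / s

    return y, grad


x = jax.random.normal(jax.random.PRNGKey(0), (4, 16))
check_grads(l1_norm, (x,), order=1, modes=["rev"])  # passes for generic inputs

The abs makes the gradient undefined exactly at zero crossings, so the numerical check is only meaningful for generic inputs; with float64 enabled the finite-difference step is small enough to avoid flipping signs.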