graphcore-research / unit-scaling

A library for unit scaling in PyTorch
https://graphcore-research.github.io/unit-scaling/
Apache License 2.0

[In Progress] Sharing Unit-Mamba Implementation #68

Open norikazu99 opened 2 months ago

norikazu99 commented 2 months ago

I've been working on a unit-scaled Mamba block and wanted to share my work, as well as ask a couple of questions. I used the https://github.com/johnma2006/mamba-minimal implementation as a skeleton.

Softplus:

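import torch
import torch.nn.functional as F
# scale_elementwise comes from the unit_scaling library

def scaled_softplus(x, beta=1.0, threshold=20.0):
    output_scale = 1/0.52103     # 1 / std of softplus(x) for x ~ N(0, 1)
    grad_scale = 1/0.20833444    # 1 / std of sigmoid(x), the derivative of softplus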

    unit_softplus = scale_elementwise(
        F.softplus, output_scale, grad_scale, constraint='to_output_scale'
    )
    return unit_softplus(x, beta, threshold)

b, s, d = 32, 64, 128
x = torch.randn(b, s, d, requires_grad=True)

# assuming we're using default values

beta = 1.0
threshold = 20.0

x.grad = None
out = F.softplus(x, beta, threshold)
out.backward(torch.ones_like(out))
print('unscaled: ', x.grad.std().item(), out.std().item())
print('')

x.grad = None
out = scaled_softplus(x, beta, threshold)
out.backward(torch.ones_like(out))
print('scaled: ', x.grad.std().item(), out.std().item())

Output (x.grad std, then out std):
unscaled:  0.2083955705165863 0.5211385488510132
scaled:  0.3999684751033783 1.0002082586288452
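The two constants in scaled_softplus can be reproduced by sampling: output_scale is 1 over the std of softplus(x), and grad_scale is 1 over the std of its derivative sigmoid(x), both for x ~ N(0, 1). The sketch below uses only the Python standard library rather than torch so that it runs on its own; the sample size and seed are arbitrary choices. It also accounts for the scaled gradient std of about 0.4 rather than 1: with constraint='to_output_scale' the backward pass reuses the output scale, so the gradient std becomes roughly 0.2084 / 0.5210, which matches the printed value.

```python
import math
import random


def softplus(x, beta=1.0, threshold=20.0):
    # Same definition as F.softplus: revert to the identity when beta * x > threshold
    if beta * x > threshold:
        return x
    return math.log1p(math.exp(beta * x)) / beta


def sigmoid(x):
    # Derivative of softplus(x) for beta = 1
    return 1.0 / (1.0 + math.exp(-x))


def std(values):
    # Population standard deviation of a list
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


rng = random.Random(0)
xs = [rng.gauss(0.0, 1.0) for _ in range(200_000)]

out_std = std([softplus(x) for x in xs])   # expected near 0.521
grad_std = std([sigmoid(x) for x in xs])   # expected near 0.208

print(f"output_scale ~ 1/{out_std:.4f}")
print(f"grad_scale   ~ 1/{grad_std:.4f}")
# 'to_output_scale' applies output_scale in both passes,
# so the scaled gradient std is grad_std * (1 / out_std) rather than 1
print(f"scaled grad std with to_output_scale ~ {grad_std / out_std:.3f}")
```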

SSM:

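import math
import torch
from einops import repeat
import unit_scaling as U
import unit_scaling.functional as UF
# scale_fwd and scale_elementwise also come from unit_scaling

# init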

d_state = 4
b, s, d = 512, 64, 64
d_inner = int(d*2)
dt_rank = math.ceil(d/16)
ssm_x_shape = (b, d_inner, d_state)

# modules

A_log = U.Parameter(torch.log(
    repeat(torch.arange(1, d_state + 1), 'n -> d n', d=d_inner)
), 'weight')
D = U.Parameter(torch.ones(d_inner), 'weight')
x_proj = U.Linear(d_inner, dt_rank + d_state * 2)
dt_proj = U.Linear(dt_rank, d_inner)

# data

x = torch.randn(b, s, d_inner, requires_grad=True)

SSM part 1:

# delta, A, B, C (forward + ssm)

a = -torch.exp(A_log.float())
A = scale_fwd(a, d_state**(-1/3.6))
delta, B, C = x_proj(x).split(split_size=[dt_rank, d_state, d_state], dim=-1)
delta2 = scaled_softplus(dt_proj(delta))

print('A: ', A.std().item(), 'B: ', B.std().item(), 'C: ', C.std().item())
print('delta: ', delta.std().item(), 'delta2: ', delta2.std().item())

# deltaA, deltaB_x (selective scan, part 1)

deltaA = UF.add(
    delta2.unsqueeze(3), 
    A.view(1, 1, A.size(0), -1)
)
delta_x = UF.add(delta2, x)
deltaB_x = UF.add(delta_x.unsqueeze(3), B.unsqueeze(2))

print('deltaA: ', deltaA.std().item(), 'delta_x: ', delta_x.std().item(), 'deltaB_x: ', deltaB_x.std().item())

Output:
A:  1.031583309173584 B:  1.002495288848877 C:  1.0055627822875977
delta:  0.9926937222480774 delta2:  0.9872873425483704
deltaA:  28.820053100585938 delta_x:  0.992746114730835 deltaB_x:  0.9943106174468994

All outputs appear to be unit-scaled except deltaA (std ~28.8), which is inflated by the torch.exp operation. When torch.exp is not used, deltaA is properly scaled, since it is computed with UF.add. How would you recommend I handle this? Thank you very much for your time.

Note: U and UF refer to unit_scaling and unit_scaling.functional (import unit_scaling as U; import unit_scaling.functional as UF).
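The inflated deltaA can be reasoned about with the standard log-normal moments: if z ~ N(mu, sigma^2), then std(e^z) = sqrt(e^(sigma^2) - 1) * e^(mu + sigma^2 / 2), which is about 2.16 for a unit normal input and grows very quickly with sigma and mu. The stdlib-only sketch below checks that formula by sampling. It also checks a property of the reference mamba-minimal code, where deltaA = exp(delta * A) uses a product rather than an add: since delta comes out of a softplus and A = -exp(A_log), every entry of deltaA lies in (0, 1). The Gaussian delta pre-activations are hypothetical stand-ins, and A follows the A_log initialization above with d_state = 4.

```python
import math
import random


def lognormal_std(mu, sigma):
    # Std of exp(z) for z ~ N(mu, sigma**2): sqrt(e^(sigma^2) - 1) * e^(mu + sigma^2 / 2)
    return math.sqrt(math.expm1(sigma ** 2)) * math.exp(mu + sigma ** 2 / 2)


def std(values):
    # Population standard deviation of a list
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


rng = random.Random(0)

# 1) exp() of a unit-scaled input is not unit-scaled
zs = [rng.gauss(0.0, 1.0) for _ in range(200_000)]
analytic = lognormal_std(0.0, 1.0)
sampled = std([math.exp(z) for z in zs])
print(f"std of exp(N(0, 1)): analytic {analytic:.3f}, sampled {sampled:.3f}")

# The std grows very quickly with the input's scale
for sigma in (0.5, 1.0, 2.0):
    print(f"sigma={sigma}: std of exp(z) = {lognormal_std(0.0, sigma):.2f}")

# 2) Sign structure in the reference mamba-minimal code:
#    delta = softplus(...) > 0 and A = -exp(A_log) < 0, so delta * A < 0
#    and deltaA = exp(delta * A) always lies in (0, 1).
d_state = 4
A_row = [-float(n) for n in range(1, d_state + 1)]  # -exp(log(arange(1, d_state + 1)))
# Hypothetical delta values: softplus of Gaussian pre-activations
deltas = [math.log1p(math.exp(rng.gauss(0.0, 1.0))) for _ in range(10_000)]
deltaA = [math.exp(d * a) for d in deltas for a in A_row]
print(f"deltaA range: [{min(deltaA):.4f}, {max(deltaA):.4f}]")
```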

norikazu99 commented 2 months ago

SSM part 2:

x = torch.randn_like(x)
ssm_x = torch.zeros(ssm_x_shape)
deltaA = torch.randn_like(deltaA)
deltaB_x = torch.randn_like(deltaB_x)
c_ = torch.randn_like(C)
d_ = torch.randn_like(D.float())

ys = []
for i in range(s):
    ssm_x = UF.add(deltaA[:, i] * ssm_x, deltaB_x[:, i])
    y = UF.add(ssm_x, c_[:, i, :].unsqueeze(1))
    y1 = y.sum(dim=-1)
    y1 = scale_fwd(y1, y.size(-1)**-0.5)
    ys.append(y1)

print('y1: ', y1.std().item(), 'ssm_x: ', ssm_x.std().item())

ys = torch.stack(ys, dim=1)  # shape (b, s, d_inner)
ys = UF.add(x * d_, ys)

print('ys: ', ys.std().item())

Output:
y1:  1.007655143737793 ssm_x:  1.0007102489471436
ys:  0.9854761362075806
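The y.sum(dim=-1) followed by scale_fwd(y1, y.size(-1)**-0.5) step works because the sum of n independent unit-variance terms has variance n, so multiplying by n**-0.5 restores unit scale. A minimal stdlib sketch, assuming the d_state = 4 entries being summed are independent (true for the random stand-ins used above, not guaranteed for real activations):

```python
import math
import random


def std(values):
    # Population standard deviation of a list
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


rng = random.Random(0)
d_state = 4
# Each row stands in for one y[b, d, :] vector of d_state unit-variance entries
rows = [[rng.gauss(0.0, 1.0) for _ in range(d_state)] for _ in range(50_000)]

summed = [sum(row) for row in rows]
rescaled = [v * d_state ** -0.5 for v in summed]

print(f"std after sum:       {std(summed):.3f}")    # close to sqrt(d_state) = 2
print(f"std after rescaling: {std(rescaled):.3f}")  # close to 1
```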

This part appears to be properly scaled: plain UF.add, without weighted-add scaling, keeps the forward scales close to 1. Would the weighted-add scaling rule described in the unit scaling paper [1] give better-scaled outputs?
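Whatever exact weighting is used, a weighted add can be reasoned about with simple variance arithmetic: for independent unit-variance a and b, Var(alpha * a + beta * b) = alpha^2 + beta^2, so scaling the sum by 1/sqrt(alpha^2 + beta^2) keeps it at unit scale, and equal weights reduce to the 1/sqrt(2) that an unweighted unit-scaled add needs. The stdlib sketch below checks this; the weight pairs are arbitrary examples and weighted_add_scale is a hypothetical helper, not a unit_scaling function. It also shows that the rule assumes independence: when b is a copy of a, the same scaling gives std sqrt(2) instead of 1, so for real activations whose operands may be correlated, measured scales can drift from this prediction.

```python
import math
import random


def weighted_add_scale(alpha, beta):
    # Hypothetical helper: for independent unit-variance a, b,
    # Var(alpha * a + beta * b) = alpha**2 + beta**2
    return 1.0 / math.sqrt(alpha ** 2 + beta ** 2)


def std(values):
    # Population standard deviation of a list
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))


rng = random.Random(0)
n = 100_000
a = [rng.gauss(0.0, 1.0) for _ in range(n)]
b = [rng.gauss(0.0, 1.0) for _ in range(n)]

results = []
for alpha, beta in [(1.0, 1.0), (0.9, 0.1), (2.0, 0.5)]:
    s = weighted_add_scale(alpha, beta)
    out = [s * (alpha * x + beta * y) for x, y in zip(a, b)]
    results.append((alpha, beta, s, std(out)))
    print(f"alpha={alpha}, beta={beta}: scale={s:.4f}, output std={std(out):.3f}")

# Independence matters: if b is a copy of a, the sum is 2a, so the same
# 1/sqrt(2) scale leaves the output std at sqrt(2) instead of 1
s = weighted_add_scale(1.0, 1.0)
corr_std = std([s * (x + x) for x in a])
print(f"fully correlated inputs: output std={corr_std:.3f}")
```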

These are all the components needed for Mamba that aren't already implemented in the unit-scaling library. Thanks for making all of this possible. Next, I'll check how well the backward-pass (bwd) scales hold up before working on the full model.