norikazu99 opened this issue 2 months ago

I've been working on a unit-scaled Mamba block and wanted to share my work, as well as ask a couple of questions. I used the https://github.com/johnma2006/mamba-minimal implementation as a skeleton.

Note: `unit_scaling` is imported as `U` and `unit_scaling.functional` as `UF`.

Softplus:
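A minimal sketch of one way a softplus can be unit-scaled, assuming the aim is just to rescale its output to roughly unit variance under a unit-normal input. This is an illustration rather than the exact code used for this block, and the `scale_fwd` import path is an assumption:

```python
import torch
import torch.nn.functional as F
from unit_scaling.scale import scale_fwd  # assumed import path for scale_fwd

# Empirical output std of softplus under a unit-normal input, used as a fixed
# rescaling constant so that unit-scale inputs give roughly unit-scale outputs.
_SOFTPLUS_SCALE = float(1 / F.softplus(torch.randn(2**20)).std())

def unit_softplus(x: torch.Tensor) -> torch.Tensor:
    # Forward rescaling only; the backward scale is not treated separately here.
    return scale_fwd(F.softplus(x), _SOFTPLUS_SCALE)
```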
SSM part 1:

All output scales seem to be properly scaled except for `deltaA`, because of the `torch.exp` operation. When `torch.exp` is not used, `deltaA` is properly scaled, since it uses `UF.add`. How would you recommend I handle this? Thank you very much for your time.
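For context on the `torch.exp` problem, a standalone illustration (independent of the Mamba code): the output scale of `exp` depends on the input distribution rather than on tensor shapes, so a unit-scale input does not stay unit-scale after it.

```python
import torch

z = torch.randn(2**20)            # unit-scale input
print(z.std().item())             # ~1.0
print(torch.exp(z).std().item())  # ~2.2 (with mean ~1.65): no longer unit-scale
```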
SSM part 2:

```python
import torch
import unit_scaling.functional as UF
from unit_scaling.scale import scale_fwd  # assumed import path for scale_fwd

# Placeholder shapes so the snippet runs standalone (values are illustrative):
b, s, d_in, n = 4, 64, 256, 16  # batch, sequence length, inner dim, state dim
x, deltaA = torch.empty(b, s, d_in), torch.empty(b, s, d_in, n)
deltaB_x, C, D = torch.empty(b, s, d_in, n), torch.empty(b, s, n), torch.empty(d_in)
ssm_x_shape = (b, d_in, n)

# Replace everything with unit-scale random tensors to check output scales.
x = torch.randn_like(x)
ssm_x = torch.zeros(ssm_x_shape)
deltaA = torch.randn_like(deltaA)
deltaB_x = torch.randn_like(deltaB_x)
c_ = torch.randn_like(C)
d_ = torch.randn_like(D.float())

ys = []
for i in range(s):
    # Recurrence: scaled add of the decayed state and the new input contribution.
    ssm_x = UF.add(deltaA[:, i] * ssm_x, deltaB_x[:, i])
    # Combine the state with C (via UF.add here), reduce over the state dim, rescale by n**-0.5.
    y = UF.add(ssm_x, c_[:, i, :].unsqueeze(1))
    y1 = y.sum(dim=-1)
    y1 = scale_fwd(y1, y.size(-1) ** -0.5)
    ys.append(y1)
    print('y1: ', y1.std().item(), 'ssm_x: ', ssm_x.std().item())

ys = torch.stack(ys, dim=1)  # shape (b, l, d_in)
ys = UF.add(x * d_, ys)      # skip connection through D
print('ys: ', ys.std().item())
```
Output:

```
y1: 1.007655143737793 ssm_x: 1.0007102489471436
ys: 0.9854761362075806
```
This part seems to be properly scaled. Not using weighted-add scaling and instead just using `UF.add` seems to do well for the forward scale. Would using the weighted-add scaling rule described in the first unit-scaling paper lead to better-scaled outputs?
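For reference, a sketch of the weighted-add rule as I read it, written with plain tensor ops rather than a library call and handling the forward scale only: for independent unit-scale inputs, `a*x + b*y` has scale `sqrt(a**2 + b**2)`, so dividing by that restores unit scale.

```python
import torch

def weighted_add(x, y, a=1.0, b=1.0):
    # Weighted-add scaling sketch: for independent unit-scale x and y,
    # var(a*x + b*y) = a**2 + b**2, so rescale by (a**2 + b**2) ** -0.5.
    return (a * x + b * y) * (a**2 + b**2) ** -0.5

x, y = torch.randn(2**20), torch.randn(2**20)
print(weighted_add(x, y, a=0.25, b=0.75).std().item())  # ~1.0
```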
These are all the components needed for Mamba that aren't already implemented in the unit-scaling library. Thanks for making all of this possible. I will be checking how well the backward scales behave before working on the full model.