junphine closed this 3 months ago

    import torch
    from torch import nn

    class ScaleNorm(nn.Module):
        def __init__(self, dim, eps = 1e-5):
            super().__init__()
            self.eps = eps
            self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5))

@junphine hey, thank you for catching this! Indeed, the sign was not correct: the exponent in the initial gain should be positive (`dim ** 0.5`), not `dim ** -0.5`.
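Why the exponent has to be positive: for a feature vector $x \in \mathbb{R}^d$, the RMS is the L2 norm divided by $\sqrt{d}$, so dividing by the RMS equals dividing by the L2 norm and multiplying by $\sqrt{d}$. ScaleNorm computes $g \cdot x / \lVert x \rVert_2$, so it matches RMSNorm at initialization only when $g = \sqrt{d}$. With $g = d^{-1/2}$ the output comes out a factor of $d$ smaller.

```latex
\operatorname{RMS}(x) = \sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2} = \frac{\lVert x \rVert_2}{\sqrt{d}}
\quad\Longrightarrow\quad
\frac{x}{\operatorname{RMS}(x)} = \sqrt{d}\,\frac{x}{\lVert x \rVert_2}
```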
It should be identical to RMSNorm, except that the learned gain is a single scalar rather than a vector with one entry per model dimension.
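Here is a minimal PyTorch sketch of what the corrected module could look like. It assumes the fix is only to flip the exponent, and that `forward` divides by the L2 norm clamped at `eps` (the thread does not show `forward`, so that part is an assumption). The `RMSNorm` class is a hypothetical reference written only for comparison. It is not the library's implementation.

```python
import torch
from torch import nn


class ScaleNorm(nn.Module):
    """L2-normalize the last dimension, then multiply by ONE learned scalar."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Positive exponent: start the scalar gain at sqrt(dim).
        self.g = nn.Parameter(torch.ones(1) * (dim ** 0.5))

    def forward(self, x):
        # Assumed forward pass: L2 norm over features, clamped to avoid division by zero.
        norm = x.norm(dim=-1, keepdim=True).clamp(min=self.eps)
        return x / norm * self.g


class RMSNorm(nn.Module):
    """Hypothetical reference RMSNorm with one learned gain PER feature."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Root-mean-square over the feature dimension, clamped like above.
        rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt().clamp(min=self.eps)
        return x / rms * self.g


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 8, 512)
    # At initialization both layers compute sqrt(512) * x / ||x||.
    print(torch.allclose(ScaleNorm(512)(x), RMSNorm(512)(x), atol=1e-5))
```

The two layers agree at initialization. Once training starts, they differ only in parameter count: ScaleNorm has 1 learned parameter and RMSNorm has `dim`. The `eps` clamps also differ slightly (one clamps the norm, the other the RMS), but that only matters for near-zero inputs.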