HCookie closed this 3 weeks ago.
Examples
--------
>>> tensor = torch.randn(3, 4, 5)
>>> scalars = ScaleTensor((0, torch.randn(3)), (1, torch.randn(4)))
>>> scaled_tensor = scalars.scale(tensor)
>>> scalars.get_scalar(tensor.ndim).shape
torch.Size([3, 4, 1])
>>> scalars.add_scalar(-1, torch.randn(5))
>>> scalars.get_scalar(tensor.ndim).shape
torch.Size([3, 4, 5])
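The docstring only shows the resulting shapes. The code below is a minimal sketch of the broadcasting behavior those shapes imply. `MiniScaleTensor` is a hypothetical stand-in, not the actual `ScaleTensor` implementation. It assumes each scalar is a 1-D tensor tied to a single dimension, with negative dimensions counted from the end (so `-1` is the last dimension, as in the `add_scalar(-1, ...)` call). Dimensions without a scalar stay at size 1, which is why the combined shape is `[3, 4, 1]` until the last-dimension scalar is added.

```python
import torch


class MiniScaleTensor:
    """Hypothetical, simplified stand-in for ScaleTensor (illustration only).

    Stores 1-D scalars keyed by the single dimension they apply to and
    combines them into one tensor that broadcasts against the input.
    """

    def __init__(self, *scalars: tuple[int, torch.Tensor]) -> None:
        self._scalars: list[tuple[int, torch.Tensor]] = []
        for dim, values in scalars:
            self.add_scalar(dim, values)

    def add_scalar(self, dim: int, values: torch.Tensor) -> None:
        # Assumption: one 1-D tensor of scaling values per dimension.
        if values.ndim != 1:
            raise ValueError("expected a 1-D tensor of scaling values")
        self._scalars.append((dim, values))

    def get_scalar(self, ndim: int) -> torch.Tensor:
        # Start from all ones with a singleton in every dimension.
        combined = torch.ones([1] * ndim)
        for dim, values in self._scalars:
            if not -ndim <= dim < ndim:
                raise IndexError(f"dimension {dim} out of range for ndim={ndim}")
            shape = [1] * ndim
            shape[dim % ndim] = values.numel()  # maps -1 to ndim - 1, etc.
            # Broadcasting multiplies this scalar in along its own dimension only.
            combined = combined * values.reshape(shape)
        return combined

    def scale(self, tensor: torch.Tensor) -> torch.Tensor:
        # Element-wise multiply; the combined scalar broadcasts to tensor's shape.
        return tensor * self.get_scalar(tensor.ndim)


if __name__ == "__main__":
    tensor = torch.randn(3, 4, 5)
    scalars = MiniScaleTensor((0, torch.randn(3)), (1, torch.randn(4)))
    print(scalars.get_scalar(tensor.ndim).shape)  # torch.Size([3, 4, 1])
    scalars.add_scalar(-1, torch.randn(5))
    print(scalars.get_scalar(tensor.ndim).shape)  # torch.Size([3, 4, 5])
```

Building the combined scalar from singleton-shaped pieces lets PyTorch broadcasting do the work. No scalar is materialized at full size until one actually covers that dimension, and the final multiplication with the input fills in any remaining size-1 dimensions.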
ScaleTensor