sIncerass / powernorm

[ICML 2020] code for "PowerNorm: Rethinking Batch Normalization in Transformers" https://arxiv.org/abs/2003.07845
GNU General Public License v3.0

A few questions regarding fairseq/modules/norms/mask_powernorm.py #15

Open congwang093 opened 1 year ago

congwang093 commented 1 year ago

Hi, first of all, thank you for your work. I've spent some time trying to understand what happens in fairseq/modules/norms/mask_powernorm.py, but I'm having some trouble. Could you please answer these questions?

  1. Was 'GroupScaling1D' (starting at line 17) specific to the data or model architecture used in the experiments, rather than part of the general PowerNorm method? As I understand it, the input has shape (T tokens or instances, B batch size, C channels). It seems to be a modified layer norm in which each value of the input tensor is divided by the root mean square of the values in its channel group, computed separately for each token and each batch element ('group_num' groups, 4 by default in 'GroupScaling1D'). I believe this was not mentioned in the paper.

  2. On these lines in the forward function of 'PowerFunction':

      if current_iter < warmup_iters:
          running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
      running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))

Since 'var' has shape (1, C, 1, 1), var.mean(dim=0, keepdim=True) is the same tensor as 'var'. Was this intentional, or perhaps an artifact from an earlier version of the code? Also, did you mean to put an else before 'running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))'?
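To make item 2 concrete, here is a small NumPy sketch (NumPy stands in for PyTorch, and update_running_phi and replay are hypothetical helpers, not part of the repository) that replays the two running_phi updates for a single channel. It checks that averaging over a size-1 axis is a no-op, and shows that without an else, every warm-up step applies the cumulative average and then the exponential moving average on top of it. The warm-up length of 10 is arbitrary (the module default is 10000); since iters is incremented before PowerFunction is called, current_iter starts at 1.

```python
import numpy as np

# A (1, C, 1, 1) array: averaging over axis 0, which has size 1, returns it unchanged.
var = np.arange(1.0, 5.0).reshape(1, 4, 1, 1)
assert np.array_equal(var.mean(axis=0, keepdims=True), var)


def update_running_phi(running_phi, var, it, warmup_iters, afwd, use_else):
    """Hypothetical helper replaying the two running_phi updates for one channel."""
    if it < warmup_iters:
        # Cumulative moving average during warm-up.
        running_phi = running_phi * (it - 1) / it + var / it
        if use_else:
            return running_phi
    # Exponential moving average; as written in the repo it also runs during warm-up.
    return afwd * running_phi + (1 - afwd) * var


def replay(batch_vars, use_else, warmup_iters=10, afwd=0.9):
    running_phi = 1.0  # the buffer is initialised with torch.ones
    for it, v in enumerate(batch_vars, start=1):
        running_phi = update_running_phi(running_phi, v, it, warmup_iters, afwd, use_else)
    return running_phi


as_written = replay([1.0, 3.0], use_else=False)
with_else = replay([1.0, 3.0], use_else=True)
print(as_written, with_else)  # about 2.1 versus exactly 2.0
```

With the else, running_phi during warm-up is the plain average of the batch statistics seen so far; without it, the estimate is additionally pulled towards the most recent batch.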

Thank you, I'd very much appreciate your time.

lumliolum commented 2 months ago

Hello,

I have a few more questions following up on what was mentioned above.

In the NormSelect function https://github.com/sIncerass/powernorm/blob/9ea6226a3203d5d6fcee07a5c6dec38ec6bc5e9f/fairseq/modules/norm_select.py#L12-L19

For batch norm, MaskSyncBatchNorm is used, a version of SyncBatchNorm that is needed for multi-GPU training, but for PowerNorm I didn't see any SyncPowerNorm. Is that because PowerNorm doesn't need a synchronized version? As I understand it, a synchronized version is needed when a layer uses batch statistics (which is why there is no sync LayerNorm).

Also, the appendix of the paper mentions that a synchronized version is used for "PN-V". If possible, could you release that part of the code as well?
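The batch-statistics point can be illustrated numerically. The NumPy sketch below (NumPy standing in for per-GPU tensors; the shard sizes are made up and the helper names are hypothetical) compares the quadratic mean psi^2 = mean(x^2) over the whole batch with two ways of combining per-rank values: a plain average of the per-rank means, which is what an all_reduce(SUM) followed by division by the world size computes, and a count-weighted combination that all-reduces sum(x^2) and the number of unmasked tokens separately. The plain average is only exact when every rank contributes the same number of non-padding tokens.

```python
import numpy as np

rng = np.random.default_rng(0)
C = 8
# Hypothetical per-rank batches of non-padding tokens (rows) x channels (columns).
equal_shards = [rng.normal(size=(16, C)) for _ in range(4)]
unequal_shards = [rng.normal(size=(n, C)) for n in (4, 10, 16, 30)]


def global_psi2(shards):
    # Reference: per-channel quadratic mean over all tokens on all ranks.
    return (np.concatenate(shards) ** 2).mean(axis=0)


def plain_average(shards):
    # Mean of per-rank means: all_reduce(SUM) of local psi^2, then divide by world size.
    return np.mean([(s ** 2).mean(axis=0) for s in shards], axis=0)


def count_weighted(shards):
    # All-reduce sum(x^2) and the token count separately, then divide.
    total = sum((s ** 2).sum(axis=0) for s in shards)
    count = sum(s.shape[0] for s in shards)
    return total / count


for name, shards in (("equal", equal_shards), ("unequal", unequal_shards)):
    ref = global_psi2(shards)
    print(name, np.allclose(plain_average(shards), ref), np.allclose(count_weighted(shards), ref))
# equal True True
# unequal False True
```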

lumliolum commented 2 months ago

The discussions about GroupScaling1D are in #9 and #8.
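For reference, here is a NumPy re-statement of what GroupScaling1D.forward (as reproduced in the code later in this thread) computes: for each token and batch element, the C channels are split into group_num contiguous groups of C / group_num channels, and each value is divided by the root mean square of its group, with eps added inside the square root. It is a reading aid under those assumptions, not code from the repository; it broadcasts the per-group moment instead of calling repeat_interleave, which gives the same result.

```python
import numpy as np


def group_scaling_1d(x, group_num=4, eps=1e-5):
    """NumPy sketch of GroupScaling1D.forward for x of shape (T, B, C)."""
    T, B, C = x.shape
    cg = C // group_num                      # channels per group; C must divide evenly
    groups = x.reshape(T, B, group_num, cg)
    # Mean of squares per (token, batch element, group), broadcast back over the
    # group's channels (the original uses repeat_interleave for the same effect).
    moment2 = (groups ** 2).mean(axis=3, keepdims=True)
    return (groups / np.sqrt(moment2 + eps)).reshape(T, B, C)


x = np.random.default_rng(0).normal(size=(5, 2, 8))   # T=5 tokens, B=2, C=8 channels

# With group_num=1 every token's feature vector is divided by its RMS over all C channels.
rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + 1e-5)
print(np.allclose(group_scaling_1d(x, group_num=1), x / rms))  # True

y = group_scaling_1d(x, group_num=4)  # 4 groups of 2 channels each
```

With group_num=1, which is the default MaskPowerNorm passes to it, this amounts to an RMSNorm-like scaling of each token without a learned gain.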

Ice-Citron commented 2 months ago

Thanks for the clarification. I'm not the original author, but I'm trying to implement this myself right now; if I succeed, I will share it here.

Ice-Citron commented 2 months ago

@lumliolum I have come up with this for now, and I checked it against the original mask_powernorm.py as carefully as I could. Can you help me check it too?

I ran it with "torchrun --standalone --nproc_per_node=4 test.py" on a machine with PyTorch installed and 4 GPUs (V100s in my case). I will double-check that the code works soon. With the help of Claude 3.5 and GPT-4, I came up with this makeshift solution. What I did was:

  1. Implemented a synchronized PowerNorm and pasted the original version into the same file.
  2. Passed the same data into both and initialised identical models: the data comes from torch.randn after torch.manual_seed(42), and I also call torch.manual_seed(42) before initialising each network, one using SyncMaskPowerNorm (run on all 4 GPUs) and one using the original MaskPowerNorm.
  3. Ran the file with controlled_print calls throughout to monitor every forward and backward pass, then computed the mean absolute difference between the parameter gradients of the two networks (one using the original PowerNorm, the other my synchronized version), along with the .mean() and .std() of each.

The result: the mean absolute difference is 0, i.e. the two networks, one using the original MaskPowerNorm and the other using my SyncPowerNorm, produce exactly the same gradients, as seen in the image below:

[Screenshot 2024-08-28 at 2:28:37 AM]

Success? I'm not sure yet. I will double-check and get back to you.

# torchrun --standalone --nproc_per_node=4 test.py

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os

# Global variable to control printing
PRINT_ALL_RANKS = True

def controlled_print(message):
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if PRINT_ALL_RANKS or local_rank == 0:
        print(f"Rank {local_rank}: {message}")

class GroupScaling1D(nn.Module):
    def __init__(self, eps=1e-5, group_num=4):
        super(GroupScaling1D, self).__init__()
        self.eps = eps
        self.group_num = group_num

    def extra_repr(self):
        return f'eps={self.eps}, group={self.group_num}'

    def forward(self, input):
        T, B, C = input.shape[0], input.shape[1], input.shape[2]
        Cg = C // self.group_num
        gn_input = input.contiguous().reshape(T, B, self.group_num, Cg)
        moment2 = torch.repeat_interleave(torch.mean(gn_input * gn_input, dim=3, keepdim=True),
            repeats=Cg, dim=-1).contiguous().reshape(T, B, C)
        return input / torch.sqrt(moment2 + self.eps)

class PowerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz,
                debug, warmup_iters, current_iter, mask_x):
        # Original PowerFunction forward code here
        ctx.eps = eps
        ctx.debug = debug
        current_iter = current_iter.item()
        ctx.current_iter = current_iter
        ctx.warmup_iters = warmup_iters
        ctx.abkw = abkw
        rmax = 1
        N, C, H, W = x.size()
        x2 = (mask_x * mask_x).mean(dim=0)

        var = x2.reshape(1, C, 1, 1)
        if current_iter <= warmup_iters:
            z = x /(var + eps).sqrt()
        else:
            z = x /(running_phi + eps).sqrt()

        y = z
        ctx.save_for_backward(z, var, weight, ema_gz)

        if current_iter < warmup_iters:
            running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
        running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))
        y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)

        controlled_print(f"Original Forward - Input mean: {x.mean().item()}, std: {x.std().item()}")
        controlled_print(f"Original Forward - Weight mean: {weight.mean().item()}, std: {weight.std().item()}")
        controlled_print(f"Original Forward - Bias mean: {bias.mean().item()}, std: {bias.std().item()}")
        controlled_print(f"Original Forward - Running phi mean: {running_phi.mean().item()}, std: {running_phi.std().item()}")
        controlled_print(f"Original Forward - Var mean: {var.mean().item()}, std: {var.std().item()}")
        controlled_print(f"Original Forward - Output mean: {y.mean().item()}, std: {y.std().item()}")

        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        debug = ctx.debug
        current_iter = ctx.current_iter
        warmup_iters = ctx.warmup_iters
        abkw = ctx.abkw

        N, C, H, W = grad_output.size()
        z, var, weight, ema_gz = ctx.saved_tensors

        y = z
        g = grad_output * weight.reshape(1, C, 1, 1)
        g = g * 1

        gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)

        approx_grad_g = (g - (1 - abkw) * ema_gz * z)
        ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))

        gx = 1. / torch.sqrt(var + eps) * approx_grad_g 
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)

        is_sync = hasattr(ctx, 'process_group')
        prefix = "Sync" if is_sync else "Original"

        controlled_print(f"{prefix} Backward - Grad output mean: {grad_output.mean().item()}, std: {grad_output.std().item()}")
        controlled_print(f"{prefix} Backward - Grad input mean: {gx.mean().item()}, std: {gx.std().item()}")
        controlled_print(f"{prefix} Backward - Grad weight mean: {grad_weight.mean().item()}, std: {grad_weight.std().item()}")
        controlled_print(f"{prefix} Backward - Grad bias mean: {grad_bias.mean().item()}, std: {grad_bias.std().item()}")

        return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None, None, None, None

class MaskPowerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9,
                 affine=True, warmup_iters=10000, group_num=1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
        self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
        self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))
        self.afwd = alpha_fwd
        self.abkw = alpha_bkw
        self.debug = False
        self.warmup_iters = warmup_iters
        self.gp = GroupScaling1D(group_num=group_num)
        self.group_num = group_num

    def extra_repr(self):
        return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
               'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)

    def forward(self, input, pad_mask=None, is_encoder=False):
        shaped_input = (len(input.shape) == 2)
        if shaped_input:
            input = input.unsqueeze(0)

        if input.dim() == 4:  # N, C, H, W
            N, C, H, W = input.shape
            input = input.permute(2, 3, 0, 1).contiguous().view(H*W, N, C)

        T, B, C = input.shape
        input = self.gp(input)

        # construct the mask_input, size to be (BxL) x C: L is the real length here
        if pad_mask is None:
            mask_input = input.clone()
        else:
            # Transpose the bn_mask (B x T -> T x B)
            bn_mask = ~pad_mask
            bn_mask = bn_mask.transpose(0, 1)

        if pad_mask is not None:
            pad_size = (~bn_mask).sum()
            mask_input = input[bn_mask, :]
        else:
            mask_input = input.clone()

        mask_input = mask_input.reshape(-1, self.num_features)

        input = input.permute(1, 2, 0).contiguous()
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)
        input = input.unsqueeze(-1)

        if self.training:
            self.iters.copy_(self.iters + 1)
            output = PowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps, \
                        self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input)

        else:
            N, C, H, W = input.size()
            var = self.running_phi
            output = input / (var + self.eps).sqrt()
            output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)

        output = output.reshape(input_shape)
        output = output.permute(2, 0, 1).contiguous()
        # Reshape it.
        if shaped_input:
            output = output.squeeze(0)

        return output

class SyncPowerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz,
                debug, warmup_iters, current_iter, mask_x, process_group, world_size):
        ctx.eps = eps
        ctx.debug = debug
        current_iter = current_iter.item()
        ctx.current_iter = current_iter
        ctx.warmup_iters = warmup_iters
        ctx.abkw = abkw
        ctx.process_group = process_group
        ctx.world_size = world_size

        N, C, H, W = x.size()
        x2 = (mask_x * mask_x).mean(dim=0)

        var = x2.reshape(1, C, 1, 1)

        # Synchronize var across GPUs
        if process_group is not None:
            dist.all_reduce(var, op=dist.ReduceOp.SUM, group=process_group)
            var /= world_size

        if current_iter <= warmup_iters:
            z = x / (var + eps).sqrt()
        else:
            z = x / (running_phi + eps).sqrt()

        y = z
        ctx.save_for_backward(z, var, weight, ema_gz)

        if current_iter < warmup_iters:
            running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
        running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))
        y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)

        controlled_print(f"Sync Forward - Input mean: {x.mean().item()}, std: {x.std().item()}")
        controlled_print(f"Sync Forward - Weight mean: {weight.mean().item()}, std: {weight.std().item()}")
        controlled_print(f"Sync Forward - Bias mean: {bias.mean().item()}, std: {bias.std().item()}")
        controlled_print(f"Sync Forward - Running phi mean: {running_phi.mean().item()}, std: {running_phi.std().item()}")
        controlled_print(f"Sync Forward - Var mean: {var.mean().item()}, std: {var.std().item()}")
        controlled_print(f"Sync Forward - Output mean: {y.mean().item()}, std: {y.std().item()}")

        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        debug = ctx.debug
        current_iter = ctx.current_iter
        warmup_iters = ctx.warmup_iters
        abkw = ctx.abkw

        N, C, H, W = grad_output.size()
        z, var, weight, ema_gz = ctx.saved_tensors

        y = z
        g = grad_output * weight.reshape(1, C, 1, 1)
        g = g * 1

        gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)

        approx_grad_g = (g - (1 - abkw) * ema_gz * z)
        ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))

        gx = 1. / torch.sqrt(var + eps) * approx_grad_g 
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)

        is_sync = hasattr(ctx, 'process_group')
        prefix = "Sync" if is_sync else "Original"

        if is_sync:
            process_group = ctx.process_group
            world_size = ctx.world_size
            dist.all_reduce(grad_weight, op=dist.ReduceOp.SUM, group=process_group)
            dist.all_reduce(grad_bias, op=dist.ReduceOp.SUM, group=process_group)
            grad_weight /= world_size
            grad_bias /= world_size

        controlled_print(f"{prefix} Backward - Grad output mean: {grad_output.mean().item()}, std: {grad_output.std().item()}")
        controlled_print(f"{prefix} Backward - Grad input mean: {gx.mean().item()}, std: {gx.std().item()}")
        controlled_print(f"{prefix} Backward - Grad weight mean: {grad_weight.mean().item()}, std: {grad_weight.std().item()}")
        controlled_print(f"{prefix} Backward - Grad bias mean: {grad_bias.mean().item()}, std: {grad_bias.std().item()}")

        return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None, None, None, None

class SyncMaskPowerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9,
                 affine=True, warmup_iters=10000, group_num=1, process_group=None):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.process_group = process_group
        self.world_size = dist.get_world_size(process_group) if process_group else 1

        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
        self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
        self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))

        self.afwd = alpha_fwd
        self.abkw = alpha_bkw

        self.eps = eps
        self.debug = False
        self.warmup_iters = warmup_iters
        self.gp = GroupScaling1D(group_num=group_num)
        self.group_num = group_num

    def extra_repr(self):
        return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
               'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)

    def forward(self, input, pad_mask=None, is_encoder=False):
        shaped_input = (len(input.shape) == 2)
        if shaped_input:
            input = input.unsqueeze(0)

        if input.dim() == 4:  # N, C, H, W
            N, C, H, W = input.shape
            input = input.permute(2, 3, 0, 1).contiguous().view(H*W, N, C)

        T, B, C = input.shape
        input = self.gp(input)

        if pad_mask is None:
            mask_input = input.clone()
        else:
            bn_mask = ~pad_mask
            bn_mask = bn_mask.transpose(0, 1)

        if pad_mask is not None:
            pad_size = (~bn_mask).sum()
            mask_input = input[bn_mask, :]
        else:
            mask_input = input.clone()

        mask_input = mask_input.reshape(-1, self.num_features)

        input = input.permute(1, 2, 0).contiguous()
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)
        input = input.unsqueeze(-1)

        if self.training:
            self.iters.copy_(self.iters + 1)
            output = SyncPowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps,
                        self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input,
                        self.process_group, self.world_size)
        else:
            N, C, H, W = input.size()
            var = self.running_phi
            output = input / (var + self.eps).sqrt()
            output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)

        output = output.reshape(input_shape)
        output = output.permute(2, 0, 1).contiguous()
        if shaped_input:
            output = output.squeeze(0)

        return output

class TestModel(nn.Module):
    def __init__(self, norm_layer):
        super(TestModel, self).__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.norm = norm_layer(64)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))

def run_test(local_rank, world_size):
    dist.init_process_group(backend="nccl", init_method='env://', world_size=world_size, rank=local_rank)

    # Set same seed for all processes
    torch.manual_seed(42)

    # Generate same data for all processes
    batch_size = 32
    torch.manual_seed(42)
    data = torch.randn(batch_size, 3, 64, 64).cuda(local_rank)

    # Ensure all processes have the same data
    dist.broadcast(data, src=0)

    if local_rank == 0:
        # Run original MaskPowerNorm on single GPU
        torch.manual_seed(42)
        model_original = TestModel(lambda num_features: MaskPowerNorm(num_features)).cuda(local_rank)
        out_original = model_original(data)
        loss_original = out_original.sum()
        loss_original.backward()

        controlled_print("Running original PowerNorm on single GPU")
        controlled_print(f"Original output mean: {out_original.mean().item()}")

    # Run SyncMaskPowerNorm on all GPUs
    torch.manual_seed(42)
    model_sync = TestModel(lambda num_features: SyncMaskPowerNorm(num_features, process_group=dist.group.WORLD)).cuda(local_rank)
    model_sync = DDP(model_sync, device_ids=[local_rank])

    controlled_print("Running SyncPowerNorm")
    out_sync = model_sync(data)
    controlled_print(f"Sync output mean on rank {local_rank}: {out_sync.mean().item()}")

    loss_sync = out_sync.sum()
    loss_sync.backward()

    if local_rank == 0:
        for (name_o, param_o), (name_s, param_s) in zip(model_original.named_parameters(), model_sync.named_parameters()):
            if param_o.grad is not None and param_s.grad is not None:
                grad_diff = (param_o.grad - param_s.grad).abs().mean().item()
                controlled_print(f"Gradient difference for {name_o}:")
                controlled_print(f"  Original - mean: {param_o.grad.mean().item()}, std: {param_o.grad.std().item()}")
                controlled_print(f"  Sync     - mean: {param_s.grad.mean().item()}, std: {param_s.grad.std().item()}")
                controlled_print(f"  Absolute difference: {grad_diff}")

    dist.barrier()
    dist.destroy_process_group()

def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    controlled_print(f"Running on rank {local_rank} of {world_size}")
    run_test(local_rank, world_size)

if __name__ == "__main__":
    main()
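One caveat about this test: every rank receives the same data (seeded, then broadcast from rank 0), so each rank's local var already equals the all-reduced average, and the comparison would pass even if the all_reduce were missing. The single-process NumPy sketch below (the shard split and helper names are illustrative assumptions) shows the difference: with identical data, un-synchronised per-rank normalisation already matches the full-batch reference, whereas with the batch split across ranks only the synchronised version does.

```python
import numpy as np

rng = np.random.default_rng(42)
eps = 1e-5
world_size = 4
full = rng.normal(size=(32, 64))  # samples x channels for the whole batch


def normalise(x, psi2):
    # PowerNorm-style scaling by the quadratic mean (no mean subtraction).
    return x / np.sqrt(psi2 + eps)


reference = normalise(full, (full ** 2).mean(axis=0))  # single-GPU result on the full batch


def run(shards, sync):
    """Normalise each simulated rank's shard with local or averaged statistics."""
    stats = [(s ** 2).mean(axis=0) for s in shards]
    if sync:
        # Equivalent to all_reduce(SUM) followed by division by world_size.
        stats = [np.mean(stats, axis=0)] * len(shards)
    return [normalise(s, p) for s, p in zip(shards, stats)]


identical = [full] * world_size        # every rank sees the whole batch, as in the test above
sharded = np.split(full, world_size)   # every rank sees 8 of the 32 samples

same_no_sync = all(np.allclose(out, reference) for out in run(identical, sync=False))
split_sync = np.allclose(np.concatenate(run(sharded, sync=True)), reference)
split_no_sync = np.allclose(np.concatenate(run(sharded, sync=False)), reference)
print(same_no_sync, split_sync, split_no_sync)  # True True False
```

In the distributed test, the corresponding change would be to give each rank data.chunk(world_size)[local_rank] and compare against the single-GPU output on the full batch.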
lumliolum commented 2 months ago

Hey @Ice-Citron

Thanks for sharing the code for SyncPowerNorm. I currently don't have access to a multi-GPU server to run your code. I should get access back in a few weeks (if possible, I will update you then).

I don't know much about how distributed systems and distributed code work (apologies in advance if I get something wrong), but I had a look anyway. Some questions I have:

Ice-Citron commented 1 month ago

@lumliolum Hi, sorry for the late reply. I have been very busy with school and other projects. I have two hours now and will try to answer your questions by verifying what is going on in my code and making sure that it works.

Ice-Citron commented 1 month ago

@lumliolum Give me a bit longer. I only just managed to understand how the maths works, and only just realised that this normalisation layer is meant for ViTs instead of language transformers, lol.

Ice-Citron commented 1 month ago

@lumliolum Actually, yes, good point. When I rushed out the code initially, I was just trying to get something running as fast as possible and didn't realise that. I will check whether an all_reduce is necessary for running_ema.

As for why gx isn't synchronised, this is what GPT-4 said, which makes a lot of sense:


The reason gx (the gradient of the loss with respect to the input x) does not require synchronization through all_reduce or similar operations in a distributed training context like DDP (Distributed Data Parallel) primarily relates to how and where it is used in neural network training.

Understanding the Use of gx in Backpropagation

  1. Local Relevance:

    • gx represents the gradient of the network's loss with respect to its inputs at a particular layer. This gradient is used locally by each GPU to compute gradients for the weights and biases of the preceding layers, whose outputs form these inputs.
    • Since each GPU processes a different subset of the data (mini-batch), gx computed on one GPU is specifically relevant to the forward pass computations and the immediate backpropagation calculations on that same GPU. There is no need to share or synchronize gx across GPUs because it does not directly influence the parameters being updated in other GPUs.
  2. Parameter Updates vs. Input Gradients:

    • In DDP, the primary concern is to synchronize the gradients of the parameters (weights and biases) across all GPUs before they are updated. This synchronization ensures that each GPU updates its model parameters based on the complete information from the entire distributed dataset, thus maintaining consistency across all replicas of the model.
    • gx, on the other hand, is used to propagate error gradients back through the network to update parameters local to each layer. Each instance of backpropagation uses gx from the subsequent layer to calculate the gradients for parameters in the current layer, but these gradients do not need to be shared across GPUs because each GPU independently handles a portion of the data.
  3. No Global Impact:

    • gx affects only the local computation within a GPU and is used to calculate gradients of parameters that will subsequently be synchronized. Since the parameters are the elements that ultimately need to be consistent across all GPUs, only their gradients require synchronization.
    • If gx were part of a computation that influenced model parameters globally (across all GPUs), then it would indeed need to be synchronized. However, in standard neural network operations, this is not the case.

Practical Implications

In a practical scenario, consider a multi-layer neural network training across multiple GPUs:

Conclusion

gx not requiring synchronization across GPUs allows each part of the model to be trained as efficiently and quickly as possible on its batch of data, relying on the synchronization of parameter gradients to ensure overall model consistency. This approach optimizes network bandwidth and computation by avoiding unnecessary data transfers and maintains the independence of local computations, which is a key advantage in distributed training setups.
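A caveat to this explanation: for layers whose forward pass uses batch statistics, the exact input gradient is not purely local. For z = x / sigma with sigma = sqrt(psi^2 + eps) and psi^2 = mean over the batch of x^2, differentiating gives dL/dx_i = g_i / sigma - (z_i / sigma) * mean_j(g_j * z_j), where g = dL/dz. Every sample's gradient depends on a batch-wide mean, which is why SyncBatchNorm performs an all_reduce in its backward pass. In the posted PowerFunction.backward, that batch mean is replaced by the buffer term (1 - abkw) * ema_gz, and sigma comes from the already-synchronised var, which is what lets gx be computed without communication. The NumPy sketch below checks the exact formula against finite differences for one channel; it is an illustration, not code from the repository.

```python
import numpy as np

rng = np.random.default_rng(0)
eps = 1e-5
x = rng.normal(size=6)  # one channel, six samples in the batch
g = rng.normal(size=6)  # upstream gradient dL/dz (grad_output * weight in the posted code)


def loss(x):
    # L = sum(g * z) with z = x / sqrt(mean(x^2) + eps), so that dL/dz = g.
    return np.sum(g * x / np.sqrt((x ** 2).mean() + eps))


sigma = np.sqrt((x ** 2).mean() + eps)
z = x / sigma
analytic = g / sigma - z / sigma * np.mean(g * z)

# Central finite differences, perturbing one sample at a time.
h = 1e-6
numeric = np.array([(loss(x + h * e) - loss(x - h * e)) / (2 * h) for e in np.eye(len(x))])

print(np.allclose(analytic, numeric, atol=1e-6))   # True
# Dropping the batch-mean term (a purely local gradient) does not match.
print(np.allclose(g / sigma, numeric, atol=1e-6))  # False
```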


Question: is it wrong if gx is synchronised?


Synchronizing gx (the gradient of the loss with respect to the input at a particular layer) across GPUs in a distributed training setup isn't typically necessary, but doing so wouldn't be inherently "wrong"—rather, it would be inefficient and largely redundant. Let's explore why this is generally avoided and what implications it might have if done:

1. Redundancy and Efficiency

2. Practical Impact of Synchronization

3. Conceptual Clarity

4. Theoretical Considerations

Conclusion

While not "wrong," synchronizing gx is not standard practice and does not align with the principles of efficient, effective distributed training. Keeping gx local optimizes computational resources and aligns with the foundational concepts of neural network training, where each part of the network learns based on the data it processes directly. The key to efficient distributed training is to minimize overhead while ensuring global consistency where it impacts the model—primarily at the level of model parameters, not intermediate gradients like gx.
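A small experiment suggests that averaging gx would change the result rather than merely waste bandwidth. Suppose the layer before the normalisation is a linear map on each rank. DDP averages the per-rank weight gradients a_r^T gx_r, where a_r are that rank's layer inputs. If gx were averaged across ranks first, each rank would instead compute a_r^T mean(gx), and the average of those generally differs from the true gradient, because each rank's gx belongs to different samples. The NumPy sketch below simulates two ranks with random activations and gradients; all shapes are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
world_size, n, d_in, d_out = 2, 4, 3, 5
acts = [rng.normal(size=(n, d_in)) for _ in range(world_size)]  # inputs to the upstream linear layer
gxs = [rng.normal(size=(n, d_out)) for _ in range(world_size)]  # local gradients w.r.t. its outputs

# What DDP computes: the average of per-rank weight gradients a_r^T gx_r.
correct = np.mean([a.T @ gx for a, gx in zip(acts, gxs)], axis=0)

# What would happen if gx were all-reduced (averaged) before continuing the backward pass.
gx_avg = np.mean(gxs, axis=0)
with_synced_gx = np.mean([a.T @ gx_avg for a in acts], axis=0)

print(np.allclose(correct, with_synced_gx))  # False
```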

Ice-Citron commented 1 month ago

@lumliolum I'm looking into running_ema now. I have a feeling it needs to be synchronised, but let's see.
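On running_ema: in the posted backward, ema_gz is updated from (approx_grad_g * z) averaged over the local batch only, so ranks that see different data accumulate different ema_gz values. Note also that DDP's default broadcast_buffers=True re-broadcasts module buffers from one rank (normally rank 0) at the start of each forward pass, so without an explicit all_reduce the other ranks' updates would be overwritten rather than averaged. The NumPy sketch below (two simulated ranks, one channel, random values; ema_gz_step is a hypothetical helper mirroring the ema_gz.add_ line) shows that averaging the local increments across equal-sized shards, as an all_reduce(SUM) divided by world_size would, reproduces the single-GPU update on the full batch.

```python
import numpy as np

rng = np.random.default_rng(0)
abkw = 0.9


def ema_gz_step(ema_gz, g, z):
    """Hypothetical mirror of the ema_gz.add_(...) line in backward, for one channel."""
    approx_grad_g = g - (1 - abkw) * ema_gz * z
    return ema_gz + np.mean(approx_grad_g * z)


g = rng.normal(size=16)  # upstream gradients (grad_output * weight) for one channel
z = rng.normal(size=16)  # normalised activations for the same channel
ema_gz = 0.0             # the buffer is initialised with zeros

full_batch = ema_gz_step(ema_gz, g, z)  # single GPU, all 16 samples
local = [ema_gz_step(ema_gz, gs, zs)    # two ranks, 8 samples each
         for gs, zs in zip(np.split(g, 2), np.split(z, 2))]
synced = ema_gz + np.mean([u - ema_gz for u in local])  # all-reduce the increment, then / world_size

print(local[0] != local[1], np.isclose(synced, full_batch))  # True True
```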

Ice-Citron commented 1 month ago

@lumliolum Yep, you're correct. Here's the final code instead. Please let me know if there are still any logic errors. I already tried to double-check it with Claude and GPT-4 and didn't find any, but let me know if anything is wrong. Thanks.

# torchrun --standalone --nproc_per_node=4 test.py

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os

# Global variable to control printing
PRINT_ALL_RANKS = True

def controlled_print(message):
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if PRINT_ALL_RANKS or local_rank == 0:
        print(f"Rank {local_rank}: {message}")

class GroupScaling1D(nn.Module):
    def __init__(self, eps=1e-5, group_num=4):
        super(GroupScaling1D, self).__init__()
        self.eps = eps
        self.group_num = group_num

    def extra_repr(self):
        return f'eps={self.eps}, group={self.group_num}'

    def forward(self, input):
        T, B, C = input.shape[0], input.shape[1], input.shape[2]
        Cg = C // self.group_num
        gn_input = input.contiguous().reshape(T, B, self.group_num, Cg)
        moment2 = torch.repeat_interleave(torch.mean(gn_input * gn_input, dim=3, keepdim=True),
            repeats=Cg, dim=-1).contiguous().reshape(T, B, C)
        return input / torch.sqrt(moment2 + self.eps)

class PowerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz,
                debug, warmup_iters, current_iter, mask_x):
        # Original PowerFunction forward code here
        ctx.eps = eps
        ctx.debug = debug
        current_iter = current_iter.item()
        ctx.current_iter = current_iter
        ctx.warmup_iters = warmup_iters
        ctx.abkw = abkw
        rmax = 1
        N, C, H, W = x.size()
        x2 = (mask_x * mask_x).mean(dim=0)

        var = x2.reshape(1, C, 1, 1)
        if current_iter <= warmup_iters:
            z = x / (var + eps).sqrt() # divide by sqrt(psi^2 + eps): the quadratic mean (RMS), not the standard deviation, since no mean is subtracted
        else:
            z = x /(running_phi + eps).sqrt() # same thing as above, but using running stats instead

        y = z
        ctx.save_for_backward(z, var, weight, ema_gz)

        if current_iter < warmup_iters:
            running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter) # cumulative moving average
        running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True)) # exponential moving average
        y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)

        controlled_print(f"Original Forward - Input mean: {x.mean().item()}, std: {x.std().item()}")
        controlled_print(f"Original Forward - Weight mean: {weight.mean().item()}, std: {weight.std().item()}")
        controlled_print(f"Original Forward - Bias mean: {bias.mean().item()}, std: {bias.std().item()}")
        controlled_print(f"Original Forward - Running phi mean: {running_phi.mean().item()}, std: {running_phi.std().item()}")
        controlled_print(f"Original Forward - Var mean: {var.mean().item()}, std: {var.std().item()}")
        controlled_print(f"Original Forward - Output mean: {y.mean().item()}, std: {y.std().item()}")

        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        debug = ctx.debug
        current_iter = ctx.current_iter
        warmup_iters = ctx.warmup_iters
        abkw = ctx.abkw

        N, C, H, W = grad_output.size()
        z, var, weight, ema_gz = ctx.saved_tensors

        y = z
        g = grad_output * weight.reshape(1, C, 1, 1)
        g = g * 1

        gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)

        approx_grad_g = (g - (1 - abkw) * ema_gz * z) # approximate backward: (1 - abkw) * ema_gz stands in for the batch mean of g * z
        ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))

        gx = 1. / torch.sqrt(var + eps) * approx_grad_g # REFER TO NOTES REGARDING BACKPROP DERIVATIVE EQUATION
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)

        is_sync = hasattr(ctx, 'process_group')
        prefix = "Sync" if is_sync else "Original"

        controlled_print(f"{prefix} Backward - Grad output mean: {grad_output.mean().item()}, std: {grad_output.std().item()}")
        controlled_print(f"{prefix} Backward - Grad input mean: {gx.mean().item()}, std: {gx.std().item()}")
        controlled_print(f"{prefix} Backward - Grad weight mean: {grad_weight.mean().item()}, std: {grad_weight.std().item()}")
        controlled_print(f"{prefix} Backward - Grad bias mean: {grad_bias.mean().item()}, std: {grad_bias.std().item()}")

        return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None, None, None, None

class MaskPowerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9,
                 affine=True, warmup_iters=10000, group_num=1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
        self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
        self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))
        self.afwd = alpha_fwd
        self.abkw = alpha_bkw
        self.debug = False
        self.warmup_iters = warmup_iters
        self.gp = GroupScaling1D(group_num=group_num)
        self.group_num = group_num

    def extra_repr(self):
        return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
               'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)

    def forward(self, input, pad_mask=None, is_encoder=False):
        shaped_input = (len(input.shape) == 2)
        if shaped_input:
            input = input.unsqueeze(0)

        if input.dim() == 4:  # N, C, H, W
            N, C, H, W = input.shape
            input = input.permute(2, 3, 0, 1).contiguous().view(H*W, N, C)

        T, B, C = input.shape
        input = self.gp(input)

        # construct the mask_input, size to be (BxL) x C: L is the real length here
        if pad_mask is None:
            mask_input = input.clone()
        else:
            # Transpose the bn_mask (B x T -> T x B)
            bn_mask = ~pad_mask
            bn_mask = bn_mask.transpose(0, 1)

        if pad_mask is not None:
            pad_size = (~bn_mask).sum()
            mask_input = input[bn_mask, :]
        else:
            mask_input = input.clone()

        mask_input = mask_input.reshape(-1, self.num_features)

        input = input.permute(1, 2, 0).contiguous()
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)
        input = input.unsqueeze(-1)

        if self.training:
            self.iters.copy_(self.iters + 1)
            output = PowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps, \
                        self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input)

        else:
            N, C, H, W = input.size()
            var = self.running_phi
            output = input / (var + self.eps).sqrt()
            output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)

        output = output.reshape(input_shape)
        output = output.permute(2, 0, 1).contiguous()
        # Reshape it.
        if shaped_input:
            output = output.squeeze(0)

        return output

class SyncPowerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz,
                debug, warmup_iters, current_iter, mask_x, process_group, world_size):
        ctx.eps = eps
        ctx.debug = debug
        current_iter = current_iter.item()
        ctx.current_iter = current_iter
        ctx.warmup_iters = warmup_iters
        ctx.abkw = abkw
        ctx.process_group = process_group
        ctx.world_size = world_size

        N, C, H, W = x.size()
        x2 = (mask_x * mask_x).mean(dim=0)

        var = x2.reshape(1, C, 1, 1)

        # Synchronize var across GPUs
        if process_group is not None:
            dist.all_reduce(var, op=dist.ReduceOp.AVG, group=process_group) # no need to divide, because already averaging

        if current_iter <= warmup_iters:
            z = x / (var + eps).sqrt()
        else:
            z = x / (running_phi + eps).sqrt()

        y = z
        ctx.save_for_backward(z, var, weight, ema_gz)

        if current_iter < warmup_iters:
            running_phi.copy_(running_phi * (current_iter-1)/current_iter + var.mean(dim=0, keepdim=True)/current_iter)
        running_phi.copy_(afwd*running_phi + (1-afwd)*var.mean(dim=0, keepdim=True))

        # Synchronize running_phi across all processes
        if process_group is not None:
            torch.distributed.all_reduce(running_phi, op=torch.distributed.ReduceOp.AVG, group=process_group)

        y = weight.reshape(1,C,1,1) * y + bias.reshape(1,C,1,1)

        controlled_print(f"Sync Forward - Input mean: {x.mean().item()}, std: {x.std().item()}")
        controlled_print(f"Sync Forward - Weight mean: {weight.mean().item()}, std: {weight.std().item()}")
        controlled_print(f"Sync Forward - Bias mean: {bias.mean().item()}, std: {bias.std().item()}")
        controlled_print(f"Sync Forward - Running phi mean: {running_phi.mean().item()}, std: {running_phi.std().item()}")
        controlled_print(f"Sync Forward - Var mean: {var.mean().item()}, std: {var.std().item()}")
        controlled_print(f"Sync Forward - Output mean: {y.mean().item()}, std: {y.std().item()}")

        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        debug = ctx.debug
        current_iter = ctx.current_iter
        warmup_iters = ctx.warmup_iters
        abkw = ctx.abkw

        N, C, H, W = grad_output.size()
        z, var, weight, ema_gz = ctx.saved_tensors

        y = z
        g = grad_output * weight.reshape(1, C, 1, 1)
        g = g * 1

        gz = (g * z).mean(dim=3).mean(dim=2).mean(dim=0)

        approx_grad_g = (g - (1 - abkw) * ema_gz * z)
        ema_gz.add_((approx_grad_g * z).mean(dim=3, keepdim=True).mean(dim=2, keepdim=True).mean(dim=0, keepdim=True))

        if ctx.process_group is not None:
            dist.all_reduce(ema_gz, op=dist.ReduceOp.AVG, group=ctx.process_group)

        gx = 1. / torch.sqrt(var + eps) * approx_grad_g 
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)

        is_sync = hasattr(ctx, 'process_group')
        prefix = "Sync" if is_sync else "Original"

        if is_sync:
            process_group = ctx.process_group
            world_size = ctx.world_size
            dist.all_reduce(grad_weight, op=dist.ReduceOp.AVG, group=process_group)
            dist.all_reduce(grad_bias, op=dist.ReduceOp.AVG, group=process_group)

        controlled_print(f"{prefix} Backward - Grad output mean: {grad_output.mean().item()}, std: {grad_output.std().item()}")
        controlled_print(f"{prefix} Backward - Grad input mean: {gx.mean().item()}, std: {gx.std().item()}")
        controlled_print(f"{prefix} Backward - Grad weight mean: {grad_weight.mean().item()}, std: {grad_weight.std().item()}")
        controlled_print(f"{prefix} Backward - Grad bias mean: {grad_bias.mean().item()}, std: {grad_bias.std().item()}")

        return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None, None, None, None

class SyncMaskPowerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, alpha_fwd=0.9, alpha_bkw=0.9,
                 affine=True, warmup_iters=10000, group_num=1, process_group=None):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.process_group = process_group
        self.world_size = dist.get_world_size(process_group) if process_group else 1

        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))

        self.register_buffer('running_phi', torch.ones(1,num_features,1,1))
        self.register_buffer('ema_gz', torch.zeros(1,num_features,1,1))
        self.register_buffer('iters', torch.zeros(1).type(torch.LongTensor))

        # All ranks initialise these buffers to the same constants (ones/zeros), so no
        # initial all_reduce is needed. Calling dist.all_reduce here would also fail with
        # the NCCL backend, because the buffers are still CPU tensors until .cuda() is called.

        self.afwd = alpha_fwd
        self.abkw = alpha_bkw

        self.eps = eps
        self.debug = False
        self.warmup_iters = warmup_iters
        self.gp = GroupScaling1D(group_num=group_num)
        self.group_num = group_num

    def extra_repr(self):
        return '{num_features}, eps={eps}, alpha_fwd={afwd}, alpha_bkw={abkw}, ' \
               'affine={affine}, warmup={warmup_iters}, group_num={group_num}'.format(**self.__dict__)

    def forward(self, input, pad_mask=None, is_encoder=False):
        shaped_input = (len(input.shape) == 2)
        if shaped_input:
            input = input.unsqueeze(0)

        if input.dim() == 4:  # N, C, H, W
            N, C, H, W = input.shape
            input = input.permute(2, 3, 0, 1).contiguous().view(H*W, N, C)

        T, B, C = input.shape
        input = self.gp(input)

        if pad_mask is None:
            mask_input = input.clone()
        else:
            bn_mask = ~pad_mask
            bn_mask = bn_mask.transpose(0, 1)

        if pad_mask is not None:
            pad_size = (~bn_mask).sum()
            mask_input = input[bn_mask, :]
        else:
            mask_input = input.clone()

        mask_input = mask_input.reshape(-1, self.num_features)

        input = input.permute(1, 2, 0).contiguous()
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)
        input = input.unsqueeze(-1)

        if self.training:
            self.iters.copy_(self.iters + 1) # maybe consider syncing this, but unlikely
            output = SyncPowerFunction.apply(input, self.weight, self.bias, self.running_phi, self.eps,
                        self.afwd, self.abkw, self.ema_gz, self.debug, self.warmup_iters, self.iters, mask_input,
                        self.process_group, self.world_size)
        else:
            N, C, H, W = input.size()
            var = self.running_phi
            output = input / (var + self.eps).sqrt()
            output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)

        output = output.reshape(input_shape)
        output = output.permute(2, 0, 1).contiguous()
        if shaped_input:
            output = output.squeeze(0)

        return output

class TestModel(nn.Module):
    def __init__(self, norm_layer):
        super(TestModel, self).__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.norm = norm_layer(64)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))

def run_test(local_rank, world_size):
    dist.init_process_group(backend="nccl", init_method='env://', world_size=world_size, rank=local_rank)

    # Set same seed for all processes
    torch.manual_seed(42)

    # Generate same data for all processes
    batch_size = 32
    torch.manual_seed(42)
    data = torch.randn(batch_size, 3, 64, 64).cuda(local_rank)

    # Ensure all processes have the same data
    dist.broadcast(data, src=0)

    if local_rank == 0:
        # Run original MaskPowerNorm on single GPU
        torch.manual_seed(42)
        model_original = TestModel(lambda num_features: MaskPowerNorm(num_features)).cuda(local_rank)
        out_original = model_original(data)
        loss_original = out_original.sum()
        loss_original.backward()

        controlled_print("Running original PowerNorm on single GPU")
        controlled_print(f"Original output mean: {out_original.mean().item()}")

    # Run SyncMaskPowerNorm on all GPUs
    torch.manual_seed(42)
    model_sync = TestModel(lambda num_features: SyncMaskPowerNorm(num_features, process_group=dist.group.WORLD)).cuda(local_rank)
    model_sync = DDP(model_sync, device_ids=[local_rank])

    controlled_print("Running SyncPowerNorm")
    out_sync = model_sync(data)
    controlled_print(f"Sync output mean on rank {local_rank}: {out_sync.mean().item()}")

    loss_sync = out_sync.sum()
    loss_sync.backward()

    if local_rank == 0:
        for (name_o, param_o), (name_s, param_s) in zip(model_original.named_parameters(), model_sync.named_parameters()):
            if param_o.grad is not None and param_s.grad is not None:
                grad_diff = (param_o.grad - param_s.grad).abs().mean().item()
                controlled_print(f"Gradient difference for {name_o}:")
                controlled_print(f"  Original - mean: {param_o.grad.mean().item()}, std: {param_o.grad.std().item()}")
                controlled_print(f"  Sync     - mean: {param_s.grad.mean().item()}, std: {param_s.grad.std().item()}")
                controlled_print(f"  Absolute difference: {grad_diff}")

    dist.barrier()
    dist.destroy_process_group()

def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    controlled_print(f"Running on rank {local_rank} of {world_size}")
    run_test(local_rank, world_size)

if __name__ == "__main__":
    main()
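Two details of the `running_phi` update in the forward passes above are easy to miss. First, `var` is already shaped `(1, C, 1, 1)`, so `var.mean(dim=0, keepdim=True)` returns it unchanged. Second, because the EMA line has no `else`, every warmup step applies the cumulative-average update and then the EMA on top of it. The pure-Python sketch below tracks a single channel with hypothetical helpers (`update_running_phi`, `simulate`, not part of the repo) so the posted behaviour can be compared with a variant that has an `else`; whether the missing `else` is intentional is not something the code itself confirms.

```python
def update_running_phi(running_phi, batch_var, current_iter, warmup_iters,
                       afwd=0.9, use_else=False):
    """One running_phi update for a single channel, following the forward pass above.

    use_else=False reproduces the posted code: during warmup the cumulative
    average is applied and then the EMA is applied on top of it.
    use_else=True is a hypothetical variant where the EMA only runs after warmup.
    """
    if current_iter < warmup_iters:
        running_phi = running_phi * (current_iter - 1) / current_iter + batch_var / current_iter
        if use_else:
            return running_phi
    return afwd * running_phi + (1 - afwd) * batch_var


def simulate(batch_vars, warmup_iters, afwd=0.9, use_else=False):
    running_phi = 1.0  # the buffer starts as torch.ones(...)
    # self.iters is incremented before PowerFunction.apply, so current_iter starts at 1
    for current_iter, batch_var in enumerate(batch_vars, start=1):
        running_phi = update_running_phi(running_phi, batch_var, current_iter,
                                         warmup_iters, afwd, use_else)
    return running_phi
```

With batch second moments `[1.0, 3.0]` and `warmup_iters=10`, the `else` variant gives the plain average 2.0, while the posted code gives about 2.1, because the EMA pulls the estimate toward the latest batch even during warmup. Once warmup is over, both variants reduce to the same EMA.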
lumliolum commented 1 month ago

Only realised just now that this normalisation layer is meant for ViTs instead of language transformers

But in the paper, they were using this norm layer for machine translation and language modeling tasks.

Here's the final code instead

You didn't paste the full final code. You only pasted the forward of the nn.Module, so can you post the autograd.Function as well?

Ice-Citron commented 1 month ago

@lumliolum ah sorry, there you go. Please try and check my full code instead. I recommend Tensordock if you wanna get started. Something like 4x Nvidia A4000 already works, which costs 0.5 USD per hour.

But in the paper, they were using this norm layer for machine translation and language modeling tasks.

Ah I see. I am just trying to make sure that I'm able to convert my .pth transformer model to Hugging Face format; after that I will start training my models on a node of 8x H100 soon.

Ice-Citron commented 1 month ago

@lumliolum Hi, just wanted to check in. Any issues so far?

Ice-Citron commented 1 month ago

@lumliolum I've just run it with my GPT2-124M implementation on the Edu_fineweb dataset. It worked fine for the first 10k steps, then the gradients exploded and the loss curve went haywire.

Double-checked, and nothing is wrong AFAIK. It cost me 70 USD in H100s for nothing, lol.

But it shows that the running-statistics normalisation technique is, no offense, a load of shit.

It probably worked in the paper because they were using a dataset that didn't have much variance and was very self-similar.
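The mechanism being suggested can be illustrated with a toy calculation. After warmup, the forward pass divides by the running estimate from before the current batch rather than by the batch's own second moment, so if a batch's second moment jumps, the normalised activations are scaled by `sqrt(batch_var / running_phi)` instead of 1 (the warmup branch, which uses the batch's own statistic, would give exactly 1). The sketch below is a hypothetical single-channel simulation with made-up numbers; it shows the mechanism only and is not evidence of what happened in the run described here.

```python
import math


def output_rms_after_warmup(batch_vars, running_phi=1.0, afwd=0.9):
    """Return the RMS of the normalised output for each post-warmup batch.

    Mirrors the post-warmup forward pass: normalise with the running estimate
    from *before* this batch, then update running_phi by EMA.
    eps is ignored for clarity.
    """
    rms = []
    for v in batch_vars:
        # if E[x^2] = v, then x / sqrt(running_phi) has RMS sqrt(v / running_phi)
        rms.append(math.sqrt(v / running_phi))
        running_phi = afwd * running_phi + (1 - afwd) * v
    return rms
```

For example, `output_rms_after_warmup([1.0, 4.0, 4.0])` gives roughly `[1.0, 2.0, 1.75]`: the output scale doubles on the first shifted batch and only relaxes back toward 1 as the EMA catches up.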

Even if it had worked and the global gradient norm hadn't exploded, it still wouldn't be that helpful: it caused a 15-20% decrease in tokens/sec. You're much better off using something simpler like LayerNorm or RMSNorm.
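For comparison, RMSNorm computes its statistic from each token's own features, so it keeps no running buffers and needs no cross-GPU `all_reduce`. A minimal pure-Python sketch for one token (the helper name `rms_norm` is hypothetical) is below; without the `weight` it is the same per-token scaling that `GroupScaling1D` applies with `group_num=1`.

```python
import math


def rms_norm(token, weight, eps=1e-6):
    """RMSNorm for one token's feature vector: x / sqrt(mean(x^2) + eps) * weight.

    The statistic comes from the token itself, so nothing is accumulated
    across batches or synchronised across processes.
    """
    mean_square = sum(x * x for x in token) / len(token)
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [x * scale * w for x, w in zip(token, weight)]
```

With `eps=0.0` and unit weights, `rms_norm([2.0, 2.0], [1.0, 1.0], eps=0.0)` returns `[1.0, 1.0]`, and any output has a mean square of 1 before the weight is applied.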

I think this paper and the entire repo are one of those cases where the author is, in a way, overfitting to his own dataset, which is why this happened.

It's a real shame, tbh, that he didn't release the code for SyncPowerNorm, and I had to spend all this time implementing something that ultimately failed. Ah well, I learnt plenty, but 70 USD just went poof.

[Two screenshots from 2024-09-16, not reproduced here]
lumliolum commented 1 month ago

Hello

Sorry, I couldn't reply as I was very busy with exams. I also didn't get server access, so I couldn't run the code. I also agree that they should have released SyncPowerNorm. This is what I will do if I get server access.

I will go with whichever gives the better result. Also, for my research problem, I want to train these LLMs without LayerNorm (this is my main objective). If nothing works here, I will just go back to BatchNorm.
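If BatchNorm-style batch statistics are used on padded language-model batches, padding positions should be excluded from those statistics. `MaskPowerNorm` does this by selecting `input[bn_mask, :]` before computing `(mask_x * mask_x).mean(dim=0)`. A minimal pure-Python sketch of that masked per-channel second moment is below; the function name and the batch-major nesting are assumptions for illustration (the repo transposes its mask to time-major first).

```python
def masked_second_moment(batch, pad_mask):
    """Per-channel mean of x^2 over non-padding tokens only.

    batch:    list over sequences, each a list over tokens, each a list of C floats
    pad_mask: same nesting over tokens; True marks a padding position
    """
    num_channels = len(batch[0][0])
    sums = [0.0] * num_channels
    count = 0
    for seq, seq_mask in zip(batch, pad_mask):
        for token, is_pad in zip(seq, seq_mask):
            if is_pad:
                continue  # padding must not contribute to the batch statistics
            for c, x in enumerate(token):
                sums[c] += x * x
            count += 1
    if count == 0:
        raise ValueError("every position is padding; statistics are undefined")
    return [s / count for s in sums]
```

In a batch where one padded position holds large values, masking keeps those values out of the per-channel estimate, which would otherwise be inflated by them.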

Ice-Citron commented 1 month ago

@lumliolum

Here's my code, which is based on Andrej Karpathy's nanoGPT.

https://huggingface.co/shng2025/GPT-Valkyrie_PN-124m/blob/denim-lake-75/train_gpt2.py

import os
import math
import time
import inspect
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from hellaswag import render_example, iterate_examples

from argparse import Namespace

import wandb

import tiktoken
import numpy as np

from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

from huggingface_hub import Repository, create_branch
import logging
import glob

# -----------------------------------------------------------------------------
os.environ['NUMEXPR_MAX_THREADS'] = '96'

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        # nh is "number of heads", hs is "head size", and C (number of channels) = nh * hs
        # e.g. in GPT-2 (124M), n_head=12, hs=64, so nh*hs=C=768 channels in the Transformer
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu    = nn.GELU(approximate='tanh')
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd)
        self.c_proj.NANOGPT_SCALE_INIT = 1

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

# -----------------------------------------------------------------------------

# given up
class GroupScaling1D(nn.Module):
    def __init__(self, eps=1e-5, group_num=1):
        super(GroupScaling1D, self).__init__()
        self.eps = eps
        self.group_num = group_num

    def extra_repr(self):
        return f'eps={self.eps}, group={self.group_num}'

    def forward(self, input):
        B, T, C = input.shape
        if self.group_num == 1:
            moment2 = torch.mean(torch.square(input), dim=-1, keepdim=True)
        else:
            Cg = C // self.group_num
            gn_input = input.view(B, T, self.group_num, Cg)  # grouped view; `input` stays (B, T, C) for the division below
            moment2 = torch.mean(gn_input * gn_input, dim=-1, keepdim=True)
            moment2 = moment2.repeat(1, 1, 1, Cg).view(B, T, C)
        return input / torch.sqrt(moment2 + self.eps)

class SyncPowerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias, running_phi, eps, afwd, abkw, ema_gz, warmup_iters, current_iter, process_group, affine):
        ctx.affine = affine
        ctx.eps = eps
        current_iter = current_iter.item()
        ctx.process_group = process_group
        ctx.abkw = abkw

        B, T, C = x.size()
        x2 = torch.mean(torch.square(x), dim=(0, 1))
        var = x2.view(1, 1, C)

        if current_iter <= warmup_iters:
            y = x * torch.rsqrt(var + eps) # Shape: (B, T, C)
        else:
            y = x * torch.rsqrt(running_phi + eps) # Shape: (B, T, C)

        ctx.save_for_backward(y, var, weight, ema_gz)

        if current_iter < warmup_iters:
            running_phi.copy_(running_phi * (current_iter-1)/current_iter + var/current_iter) # MISTAKE? UNSURE <-- I think it's correct, because we want one value for every feature in dim=-1, calculated across the batch
        running_phi.copy_(afwd*running_phi + (1-afwd)*var) # MISTAKE? UNSURE

        if process_group is not None: # and (current_iter % 100 or current_iter == args.max_steps): # experimental, to try and reduce time spent accessing memory
            torch.distributed.all_reduce(running_phi, op=torch.distributed.ReduceOp.AVG, group=process_group)

        if affine:
            y = weight.view(1, 1, C) * y + bias.view(1, 1, C)

        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps
        abkw = ctx.abkw

        y, var, weight, ema_gz = ctx.saved_tensors

        if ctx.affine:
            g = grad_output * weight.view(1, 1, -1)
        else:
            g = grad_output

        approx_grad_g = g - (1 - abkw) * ema_gz * y
        ema_gz.add_(torch.mean(approx_grad_g * y, dim=(0, 1), keepdim=True))

        if ctx.process_group is not None: # and (current_iter % 100 or current_iter == args.max_steps): # experimental, to try and reduce time spent accessing memory
            dist.all_reduce(ema_gz, op=dist.ReduceOp.AVG, group=ctx.process_group)

        gx = torch.rsqrt(var + eps) * approx_grad_g

        if ctx.affine:
            grad_weight = torch.sum(grad_output * y, dim=(0, 1))
            grad_bias = torch.sum(grad_output, dim=(0, 1))
            if ctx.process_group is not None:
                dist.all_reduce(grad_weight, op=dist.ReduceOp.AVG, group=ctx.process_group)
                dist.all_reduce(grad_bias, op=dist.ReduceOp.AVG, group=ctx.process_group)
        else:
            grad_weight = None
            grad_bias = None

        return gx, grad_weight, grad_bias, None, None, None, None, None, None, None, None, None

class SyncPowerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-3, alpha_fwd=0.9, alpha_bkw=0.9,
                 affine=True, warmup_iters=715, process_group=None, group_num=1):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.process_group = process_group

        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.weight = None
            self.bias = None

        self.register_buffer('running_phi', torch.ones(1, 1, num_features))     # I THINK THIS IS LIKELY THE ISSUE, IT'S NOT MEANT TO BE A PARAMETER BUT WE ARE TREATING IT LIKE ONE???
        self.register_buffer('ema_gz', torch.zeros(1, 1, num_features))
        self.register_buffer('iters', torch.ones(1).type(torch.LongTensor))

        self.afwd = alpha_fwd
        self.abkw = alpha_bkw

        self.warmup_iters = warmup_iters
        self.grad_accum_steps = args.gradient_accumulation_steps
        self.accum_count = 0
        self.group_scaling = GroupScaling1D(eps=eps, group_num=group_num)

    def extra_repr(self):
        return f'{self.num_features}, eps={self.eps}, alpha_fwd={self.afwd}, alpha_bkw={self.abkw}, ' \
               f'affine={self.affine}, warmup={self.warmup_iters}'

    def forward(self, input):
        B, T, C = input.size()
        assert C == self.num_features, f"Input features {C} doesn't match num_features {self.num_features}"

        # Apply GroupScaling1D
        input = self.group_scaling(input)

        if self.training:
            self.accum_count += 1
            if self.accum_count >= self.grad_accum_steps:
                self.iters.add_(1)
                self.accum_count = 0

            output = SyncPowerFunction.apply(input, self.weight if self.affine else None, self.bias if self.affine else None,
                                         self.running_phi, self.eps, self.afwd, self.abkw, self.ema_gz, self.warmup_iters, self.iters,
                                         self.process_group, self.affine)
        else:
            # var = self.running_phi
            output = input * torch.rsqrt(self.running_phi + self.eps)
            if self.affine: # if not, do nothing.
                output = self.weight.reshape(1, 1, C) * output + self.bias.reshape(1, 1, C)

        return output # Shape: (B, T, C)

# -----------------------------------------------------------------------------

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = SyncPowerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = SyncPowerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024 # max sequence length
    vocab_size: int = 50257 # number of tokens: 50,000 BPE merges + 256 bytes tokens + 1 <|endoftext|> token
    n_layer: int = 12 # number of layers
    n_head: int = 12 # number of heads
    n_embd: int = 768 # embedding dimension

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = SyncPowerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # weight sharing scheme
        self.transformer.wte.weight = self.lm_head.weight

        # init params
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        # idx is of shape (B, T)
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
        # forward the token and position embeddings
        pos = torch.arange(0, T, dtype=torch.long, device=idx.device) # shape (T)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (T, n_embd)
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
        x = tok_emb + pos_emb
        # forward the blocks of the transformer
        for block in self.transformer.h:
            x = block(x)
        # forward the final layernorm and the classifier
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """Loads pretrained GPT-2 model weights from huggingface"""
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, weight_decay, learning_rate, device_type):
        # start with all of the candidate parameters (that require grad)
        param_dict = {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if master_process:
            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == "cuda"
        if master_process:
            print(f"using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer

# -----------------------------------------------------------------------------
def load_tokens(filename):
    npt = np.load(filename)
    npt = npt.astype(np.int32) # added after video
    ptt = torch.tensor(npt, dtype=torch.long)
    return ptt

class DataLoaderLite:
    def __init__(self, B, T, process_rank, num_processes, split):
        self.B = B
        self.T = T
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {'train', 'val'}

        # get the shard filenames
        data_root = "./../edu_fineweb10B"
        shards = os.listdir(data_root)
        shards = [s for s in shards if split in s] # keep only the shard files for this split
        shards = sorted(shards)
        shards = [os.path.join(data_root, s) for s in shards]
        self.shards = shards
        assert len(shards) > 0, f"no shards found for split {split}"
        if master_process:
            print(f"found {len(shards)} shards for split {split}")

        # NEW IMPL
        self.current_shard = 0
        self.current_position = self.B * self.T * self.process_rank

        self.reset()

    def reset(self):
        # state, init at shard zero
        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank

    def set_state(self, state):
        self.current_shard = state['current_shard']
        self.current_position = state['current_position']
        self.tokens = load_tokens(self.shards[self.current_shard])

    def get_state(self):
        return {
            'current_shard': self.current_shard,
            'current_position': self.current_position
        }

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        # advance the position in the tensor
        self.current_position += B * T * self.num_processes
        # if loading the next batch would be out of bounds, advance to next shard
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_position = B * T * self.process_rank
        return x, y

def setup_logging(project_name, args):
    logger = logging.getLogger(__name__)
    dir_name = "./log"
    existed = os.path.exists(dir_name) # check before creating, otherwise it always "already exists"
    os.makedirs(dir_name, exist_ok=True)
    print(f"Directory '{dir_name}' {'already exists' if existed else 'was created'}.")

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(f"log/debug_{ddp_rank}.log"),
            logging.StreamHandler(),
        ],
    )

    if master_process:
        wandb.init(project=project_name, config=args, dir="./../")
        run_name = wandb.run.name
        wandb_id = wandb.run.id
        logger.setLevel(logging.INFO)
        print(f"Starting new run: {run_name}")
    else:
        run_name = ""
        wandb_id = ""
        logger.setLevel(logging.ERROR)

    return logger, run_name, wandb_id

def resume_logging(project_name, run_id, args):
    logger = logging.getLogger(__name__)
    dir_name = "./log"
    existed = os.path.exists(dir_name) # check before creating, otherwise it always "already exists"
    os.makedirs(dir_name, exist_ok=True)
    print(f"Directory '{dir_name}' {'already exists' if existed else 'was created'}.")

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(f"log/debug_{ddp_rank}.log"),
            logging.StreamHandler(),
        ],
    )

    if master_process:
        wandb.init(project=project_name, id=run_id, resume="must", config=args, dir='./../')
        run_name = wandb.run.name
        logger.setLevel(logging.INFO)
        print(f"Resuming run: {run_name}")
    else:
        run_name = ""
        logger.setLevel(logging.ERROR)

    return logger, run_name

def log_metrics(metrics):
    if master_process:
        wandb.log(metrics)

def save_checkpoint(model, optimizer, step, val_loss, run_name, train_loader_state, val_loader_state, wandb_id):
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': step,
        'val_loss': val_loss,
        'run_name': run_name,
        'train_loader_state': train_loader_state,
        'val_loader_state': val_loader_state,
        'wandb_id': wandb_id,
    }

    checkpoint_path = os.path.join(log_dir, f"checkpoint_{step}.pt")
    torch.save(checkpoint, checkpoint_path)
    return checkpoint_path

def load_checkpoint(checkpoint_path, model, optimizer, train_loader, val_loader):
    checkpoint = torch.load(checkpoint_path, map_location=device) # load onto this rank's device, not the device it was saved from
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    train_loader.set_state(checkpoint['train_loader_state'])
    val_loader.set_state(checkpoint['val_loader_state'])

    return checkpoint['step'], checkpoint['val_loss'], checkpoint['run_name'], checkpoint['wandb_id']

# logging powernorm stats
def log_powernorm_stats(model):
    stats = {}
    for name, module in model.named_modules():
        if isinstance(module, SyncPowerNorm):
            running_phi = module.get_buffer('running_phi')
            if running_phi is not None:
                stats[f"{name}/running_phi"] = running_phi.mean().item()
            ema_gz = module.get_buffer('ema_gz')
            if ema_gz is not None:
                stats[f"{name}/ema_gz"] = ema_gz.mean().item()
            iters = module.get_buffer('iters')
            if iters is not None:
                stats[f"{name}/iters"] = iters.item()
    return stats

# -----------------------------------------------------------------------------
# helper function for HellaSwag eval
# takes tokens, mask, and logits, returns the index of the completion with the lowest loss

def get_most_likely_row(tokens, mask, logits):
    # evaluate the autoregressive loss at all positions
    shift_logits = (logits[..., :-1, :]).contiguous()
    shift_tokens = (tokens[..., 1:]).contiguous()
    flat_shift_logits = shift_logits.view(-1, shift_logits.size(-1))
    flat_shift_tokens = shift_tokens.view(-1)
    shift_losses = F.cross_entropy(flat_shift_logits, flat_shift_tokens, reduction='none')
    shift_losses = shift_losses.view(tokens.size(0), -1)
    # now get the average loss just for the completion region (where mask == 1), in each row
    shift_mask = (mask[..., 1:]).contiguous() # we must shift mask, so we start at the last prompt token
    masked_shift_losses = shift_losses * shift_mask
    # sum and divide by the number of 1s in the mask
    sum_loss = masked_shift_losses.sum(dim=1)
    avg_loss = sum_loss / shift_mask.sum(dim=1)
    # now we have a loss for each of the 4 completions
    # the one with the lowest loss should be the most likely
    pred_norm = avg_loss.argmin().item()
    return pred_norm

# -----------------------------------------------------------------------------
# simple launch:
# python train_gpt2.py
# DDP launch for e.g. 8 GPUs:
# torchrun --standalone --nproc_per_node=8 train_gpt2.py

# run the training loop

# set up DDP (distributed data parallel).
# torchrun command sets the env variables RANK, LOCAL_RANK, and WORLD_SIZE
ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
if ddp:
    # use of DDP atm demands CUDA, we set the device appropriately according to rank
    assert torch.cuda.is_available(), "for now i think we need CUDA for DDP"
    init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
else:
    # vanilla, non-DDP run
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    # attempt to autodetect device
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    print(f"using device: {device}")

# -----------------------------------------------------------------------------

# GPTesla - 111M param setup in comment. Modification to make lighter training requirement needed
config = {
    "weight_decay": 0.1,
    # "lr_scheduler_type": "cosine",
    "gradient_accumulation_steps": (2**16 * 7) // (56 * 1024 * ddp_world_size),  # total_batch_size // (B * T * ddp_world_size)
    "max_eval_steps": 20,
    "seq_length": 1024,

    # New centralised parameters
    "project_name": "shng2025/GPT-Valkyrie_PN-124m",
    "total_batch_size": 2**16 * 7, # temporarily because 6 GPUs  # 2**19, ~0.5M, in number of tokens
    "micro_batch_size": 56,
    "max_lr": 6e-4,
    "min_lr": 6e-5,  # 10% of max_lr; not read directly, since min_lr is computed below as max_lr * weight_decay (0.1), which gives the same value
    "warmup_steps": 715,
    "max_steps": 21797, # had to be scaled up after 2001st step, as memory ran out when DDP
    "val_every": 500,           # EVALUATION
    "generate_every": 500,      # EVALUATION
    "hellaswag_every": 500,     # EVALUATION
    "save_every": 2000,           # SAVE CHECKPOINTING   
    "log_dir": "./log",
    "device": "auto",  # "auto", "cpu", "cuda", or "mps"
    "use_compile": True,
    "grad_clip": 1.0,
    "num_return_sequences": 4,
    "max_generate_length": 32,
    "top_k": 50,

    "resume_from_checkpoint": False,
}

args = Namespace(**config)
samples_per_step = torch.cuda.device_count() * args.micro_batch_size

project_name = args.project_name

# Logging - DEPRECATED
if master_process:
    pass
    # run_name, wandb_id = setup_logging(project_name.split("/")[1])
    # print(f"Weights and Biases run name: {run_name}")
    # print(f"Weights and Biases run id  : {wandb_id}")

# -----------------------------------------------------------------------------

# added after video, pytorch can be serious about its device vs. device_type distinction
device_type = "cuda" if device.startswith("cuda") else "cpu"

torch.manual_seed(1337)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1337)

enc = tiktoken.get_encoding("gpt2")

total_batch_size = args.total_batch_size # 2**19, ~0.5M, in number of tokens
B = args.micro_batch_size # micro batch size
T = args.seq_length # sequence length
assert total_batch_size % (B * T * ddp_world_size) == 0, "make sure total_batch_size is divisible by B * T * ddp_world_size"
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)
if master_process:
    print(f"total desired batch size: {total_batch_size}")
    print(f"=> calculated gradient accumulation steps: {grad_accum_steps}")

train_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train")
val_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val")

torch.set_float32_matmul_precision('high')

# create model
model = GPT(GPTConfig(vocab_size=50304))
# model = GPT.from_pretrained("gpt2") # or init from OpenAI GPT-2
model.to(device)
use_compile = args.use_compile # torch.compile interferes with HellaSwag eval and generation. TODO fix
if use_compile:
    model = torch.compile(model)
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module if ddp else model # always contains the "raw" unwrapped model

max_lr = args.max_lr
min_lr = max_lr * args.weight_decay
warmup_steps = args.warmup_steps
max_steps = args.max_steps # 19,073 steps is ~1 epoch, if data is 10B tokens and batch size 0.5M tokens
def get_lr(it):
    # 1) linear warmup for warmup_steps steps
    if it < warmup_steps:
        return max_lr * (it+1) / warmup_steps
    # 2) if it > max_steps, return min learning rate
    if it > max_steps:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0
    return min_lr + coeff * (max_lr - min_lr)

# optimize!
optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay, learning_rate=args.max_lr, device_type=device_type)

# create the log directory we will write checkpoints to and log to
log_dir = args.log_dir
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"log.txt")
with open(log_file, "w") as f: # open for writing to clear the file
    pass

"""
# Initialize HuggingFace repository
if master_process:
    new_branch_name = run_name
    create_branch(project_name, repo_type="model", branch=new_branch_name)
    hf_repo = Repository("./", clone_from=project_name, revision=run_name)
"""

# Training loop
starting_step = 0

if args.resume_from_checkpoint:
    checkpoint_dir = args.log_dir
    checkpoint_pattern = os.path.join(checkpoint_dir, "checkpoint_*.pt")
    checkpoint_files = glob.glob(checkpoint_pattern)
    latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
    checkpoint_path = latest_checkpoint

    # Use the load_checkpoint function here
    starting_step, val_loss, run_name, wandb_id = load_checkpoint(checkpoint_path, raw_model, optimizer, train_loader, val_loader)
    starting_step += 1 # resume from the step after the checkpoint, so the saved step isn't repeated

    logger, run_name = resume_logging(project_name.split("/")[1], wandb_id, args)
    print(f"Resuming from checkpoint: {checkpoint_path}")
    print(f"Weights and Biases run name: {run_name}")
    print(f"Resuming from step: {starting_step}")

    # Initialize HuggingFace repository <-- UNSURE IF NEEDED
    if master_process:
        new_branch_name = run_name
        # create_branch(project_name, repo_type="model", branch=new_branch_name)
        hf_repo = Repository("./", clone_from=project_name, revision=run_name)

    # Local subprocess for git pulling and checking out the newest branch
    if master_process:
        import subprocess
        subprocess.run(["git", "fetch", "origin"])
        subprocess.run(["git", "checkout", new_branch_name])
        subprocess.run(["git", "pull", "origin", new_branch_name])

    if master_process:
        print(f"Resuming from checkpoint at step {starting_step}")
else:
    starting_step = 0
    logger, run_name, wandb_id = setup_logging(project_name.split("/")[1], args)
    print(f"Weights and Biases run name: {run_name}")
    print(f"Weights and Biases run id  : {wandb_id}")

    # Initialize HuggingFace repository
    if master_process:
        new_branch_name = run_name
        create_branch(project_name, repo_type="model", branch=new_branch_name)
        hf_repo = Repository("./", clone_from=project_name, revision=run_name)

    if master_process:
        print(f"Starting new run: {run_name}")

# Training Loop
for step in range(starting_step, max_steps):
    t0 = time.time()
    last_step = (step == max_steps - 1)

    # once in a while evaluate our validation loss
    if step % args.val_every == 0 or last_step:
        model.eval()
        val_loader.reset()
        with torch.no_grad():
            val_loss_accum = 0.0
            val_loss_steps = args.max_eval_steps
            for _ in range(val_loss_steps):
                x, y = val_loader.next_batch()
                x, y = x.to(device), y.to(device)
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(x, y)
                loss = loss / val_loss_steps
                val_loss_accum += loss.detach()
        if ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        log_metrics({"loss/validation": val_loss_accum.item()})

    # once in a while evaluate hellaswag
    if (step % args.hellaswag_every == 0 or last_step) and (not use_compile):
        num_correct_norm = 0
        num_total = 0
        for i, example in enumerate(iterate_examples("val")):
            # only process examples where i % ddp_world_size == ddp_rank
            if i % ddp_world_size != ddp_rank:
                continue
            # render the example into tokens and labels
            _, tokens, mask, label = render_example(example)
            tokens = tokens.to(device)
            mask = mask.to(device)
            # get the logits
            with torch.no_grad():
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(tokens)
                pred_norm = get_most_likely_row(tokens, mask, logits)
            num_total += 1
            num_correct_norm += int(pred_norm == label)
        # reduce the stats across all processes
        if ddp:
            num_total = torch.tensor(num_total, dtype=torch.long, device=device)
            num_correct_norm = torch.tensor(num_correct_norm, dtype=torch.long, device=device)
            dist.all_reduce(num_total, op=dist.ReduceOp.SUM)
            dist.all_reduce(num_correct_norm, op=dist.ReduceOp.SUM)
            num_total = num_total.item()
            num_correct_norm = num_correct_norm.item()
        acc_norm = num_correct_norm / num_total
        if master_process:
            log_metrics({"hella/swag": acc_norm, "hella/correct norm": num_correct_norm, "hella/num total": num_total})
            print(f"HellaSwag accuracy: {num_correct_norm}/{num_total}={acc_norm:.4f}")
            with open(log_file, "a") as f:
                f.write(f"{step} hella {acc_norm:.4f}\n")

    # once in a while generate from the model (except step 0, which is noise)
    if ((step > 0 and step % args.generate_every == 0) or last_step) and (not use_compile):
        model.eval()
        num_return_sequences = args.num_return_sequences
        max_length = args.max_generate_length
        tokens = enc.encode("Hello, I'm a language model,")
        tokens = torch.tensor(tokens, dtype=torch.long)
        tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
        xgen = tokens.to(device)
        sample_rng = torch.Generator(device=device)
        sample_rng.manual_seed(42 + ddp_rank)
        while xgen.size(1) < max_length:
            # forward the model to get the logits
            with torch.no_grad():
                with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
                    logits, loss = model(xgen) # (B, T, vocab_size)
                # take the logits at the last position
                logits = logits[:, -1, :] # (B, vocab_size)
                # get the probabilities
                probs = F.softmax(logits, dim=-1)
                # do top-k sampling of 50 (huggingface pipeline default)
                # topk_probs here becomes (num_return_sequences, top_k) = (4, 50), and so does topk_indices
                topk_probs, topk_indices = torch.topk(probs, args.top_k, dim=-1) # top_k == 50 in this case
                # select a token from the top-k probabilities
                # note: multinomial does not demand the input to sum to 1
                ix = torch.multinomial(topk_probs, 1, generator=sample_rng) # (B, 1)
                # gather the corresponding indices
                xcol = torch.gather(topk_indices, -1, ix) # (B, 1)
                # append to the sequence
                xgen = torch.cat((xgen, xcol), dim=1)
        # print the generated text
        for i in range(num_return_sequences):
            tokens = xgen[i, :max_length].tolist()
            decoded = enc.decode(tokens)
            print(f"rank {ddp_rank} sample {i}: {decoded}")

    if (step % args.save_every == 0 or last_step) and step > 0:
        if master_process:
            # Save checkpoint and push to HuggingFace
            checkpoint_path = save_checkpoint(
                raw_model, optimizer, step, val_loss_accum.item(), run_name,
                train_loader.get_state(), val_loader.get_state(), wandb_id
            )
            hf_repo.push_to_hub(commit_message=f"Checkpoint at step {step}")
            print(f"Saved checkpoint and pushed to HuggingFace at step {step}")

    # do one step of the optimization
    model.train()
    optimizer.zero_grad()
    loss_accum = 0.0

    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        # added after video, this field is also used by the forward pass.
        if ddp:
            model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # we have to scale the loss to account for gradient accumulation,
        # because the gradients just add on each successive backward().
        # addition of gradients corresponds to a SUM in the objective, but
        # instead of a SUM we want MEAN. Scale the loss here so it comes out right
        loss = loss / grad_accum_steps
        loss_accum += loss.detach()

        # Add NaN check here
        if not torch.isnan(loss):
            loss.backward()
        else:
            print(f"NaN loss detected at step {step}, micro_step {micro_step}. Skipping backward.")
            if master_process:
                wandb.alert(
                    title="NaN Loss Detected",
                    text=f"NaN loss detected at step {step}, micro_step {micro_step}. Skipping backward.",
                    level=wandb.AlertLevel.WARN
                )
            break

    if ddp:
        dist.all_reduce(loss_accum, op=dist.ReduceOp.AVG)

    # Add another NaN check here before optimizer step
    if not torch.isnan(loss_accum):
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) # grad_clip == 1.0 by default
        # determine and set the learning rate for this iteration
        lr = get_lr(step)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        optimizer.step()
    else:
        print("successfully broke")
        print(f"NaN accumulated loss detected at step {step}. Skipping optimizer step and stopping training.")
        if master_process: # wandb is only initialised on the master process
            wandb.alert(
                title="NaN Accumulated Loss Detected",
                text=f"NaN accumulated loss detected at step {step}. Skipping optimizer step and stopping training.",
                level=wandb.AlertLevel.ERROR
            )
        break

    if device_type == "cuda":
        torch.cuda.synchronize() # wait for the GPU to finish work
    t1 = time.time()
    dt = t1 - t0 # time difference in seconds
    tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
    tokens_per_sec = tokens_processed / dt
    if master_process:
        pn_stats = log_powernorm_stats(raw_model)
        log_metrics({
            "lr": lr, # get_lr()
            "samples": step * samples_per_step,
            "steps": step,
            "loss/train": loss_accum.item(),
            # file specific addition
            "global gradient norm": norm,
            "dt": dt,
            "tok per sec": tokens_per_sec,
            **pn_stats  # Unpack the PowerNorm stats into the metrics
        })
        print(f"step {step:5d} | loss: {loss_accum.item():.6f} | lr {lr:.4e} | norm: {norm:.4f} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")
        with open(log_file, "a") as f:
            f.write(f"{step} train {loss_accum.item():.6f}\n")

if ddp:
    destroy_process_group()

# Final save and push
if master_process:
    final_checkpoint_path = checkpoint_path = save_checkpoint(raw_model, optimizer, step, val_loss_accum.item(), run_name, train_loader.get_state(), val_loader.get_state(), wandb_id)
    hf_repo.push_to_hub(commit_message="Final model")
    print("Training completed. Final model pushed to HuggingFace.")
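A quick sanity check of the batch-size arithmetic in the config above: the script asserts that total_batch_size is divisible by B * T * ddp_world_size before computing the gradient accumulation steps. The helper grad_accum_steps below is a hypothetical stand-alone mirror of that assertion, not part of the script:

```python
def grad_accum_steps(total_batch_size: int, B: int, T: int, world_size: int) -> int:
    """Hypothetical mirror of the script's check: tokens per step must split evenly."""
    tokens_per_micro_step = B * T * world_size
    if total_batch_size % tokens_per_micro_step != 0:
        raise ValueError(
            f"total_batch_size={total_batch_size} is not divisible by "
            f"B*T*world_size={tokens_per_micro_step}"
        )
    return total_batch_size // tokens_per_micro_step


total = 2**16 * 7  # 458,752 tokens, as in the config above
for world_size in (1, 2, 4, 6, 8):
    try:
        print(world_size, grad_accum_steps(total, B=56, T=1024, world_size=world_size))
    except ValueError as err:
        print(world_size, "->", err)
```

With total_batch_size = 458,752, B = 56 and T = 1024, the check passes for 1, 2, 4 or 8 processes (8, 4, 2 and 1 accumulation steps respectively) but fails for 6 processes.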
Ice-Citron commented 1 month ago

@lumliolum

I've double-checked it for errors as much as I could and I think it's fine. I even checked it by printing out a physical copy.

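To be honest, what's weird about PowerNorm is that I'm not sure why it uses the "GroupScaling1D" function at all. With group_num=1 (the default), it's basically layer normalisation, except that it only divides each token's features by the root of their second moment and never subtracts the mean.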

This normalisation step, i.e. dividing by the square root of the second moment (the mean of squares, not the variance, since the mean isn't subtracted), isn't explained in the paper, yet it's used here.

Without the "GroupScaling1D" step, throughput (tokens/sec) improves by about 11%, but at a cost: at around step 2000 or 3000, NaN errors appear and either the code crashes or the model effectively dies because it is full of NaNs.
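To make the "GroupScaling1D" point above concrete, here is a minimal NumPy re-implementation of the scaling step as read from mask_powernorm.py (a sketch, not the repository's code; the helper names group_scaling_1d and plain_layer_norm are hypothetical). It assumes the (T, B, C) input layout described in the original question and compares group_num=1 against an explicit RMS scaling and against LayerNorm without a learned gain or bias:

```python
import numpy as np


def group_scaling_1d(x: np.ndarray, group_num: int = 1, eps: float = 1e-5) -> np.ndarray:
    """Divide each channel group of x (shape (T, B, C)) by the root of its second moment."""
    T, B, C = x.shape
    assert C % group_num == 0, "C must be divisible by group_num"
    g = x.reshape(T, B, group_num, C // group_num)
    # second moment (mean of squares) for every token, batch element and group
    moment2 = np.mean(g * g, axis=-1, keepdims=True)
    return (g / np.sqrt(moment2 + eps)).reshape(T, B, C)


def plain_layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """LayerNorm over the last axis, without a learned gain or bias."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3, 8))  # (T tokens, B batch, C channels)
y = group_scaling_1d(x, group_num=1)

# With group_num=1 this is an RMS scaling over the channel axis ...
rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + 1e-5)
print(np.allclose(y, x / rms))  # True

# ... but not LayerNorm, because the mean is never subtracted.
print(np.allclose(y, plain_layer_norm(x)))  # False
```

The two only coincide when each token's features already have zero mean, which is exactly the centring step LayerNorm adds.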

Ice-Citron commented 1 month ago

@lumliolum

What are you using batch normalisation for, though? Transformers?

To be honest, I think this is just a case where the authors used a lower-quality dataset containing very similar kinds of words, and so got a better metric, since PN is likely suited to smaller models and very similar data. In other words, they were overfitting.

I, on the other hand, am using fineweb_edu, a high-quality text corpus compiled very recently by Hugging Face (source file below). It likely has a lot of variance, i.e. many different types of data, which would explain why the gradients began to explode around the middle of training.

In the end, I don't really like the idea of calculating population statistics, especially for something like a transformer, which tends to be very large (my 124M model is only the base size and is already about 450 MB); hence the name LLMs.
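For reference, the "population statistics" here are per-channel running averages of the second moment, like the running_phi buffer discussed in the original question. Below is a minimal NumPy sketch, assuming the warm-up cumulative average followed by an exponential moving average quoted there (with the else branch the question suggests). The class name RunningSecondMoment, the momentum afwd=0.9 and warmup_iters=10 are illustrative choices, not the repository's values:

```python
import numpy as np


class RunningSecondMoment:
    """Hypothetical per-channel population estimate of E[x^2], in the spirit of running_phi."""

    def __init__(self, num_channels: int, afwd: float = 0.9, warmup_iters: int = 10):
        self.running_phi = np.ones(num_channels)  # one value per channel
        self.afwd = afwd                          # EMA momentum (illustrative value)
        self.warmup_iters = warmup_iters
        self.iters = 0

    def update(self, x: np.ndarray) -> None:
        """x: (N, C) activations of one batch, N = number of tokens in the batch."""
        var = np.mean(x * x, axis=0)  # batch second moment per channel
        self.iters += 1
        if self.iters < self.warmup_iters:
            # warm-up: plain cumulative average of the batch statistics seen so far
            self.running_phi = self.running_phi * (self.iters - 1) / self.iters + var / self.iters
        else:
            # afterwards: exponential moving average
            self.running_phi = self.afwd * self.running_phi + (1 - self.afwd) * var


rng = np.random.default_rng(0)
stats = RunningSecondMoment(num_channels=768)  # 768 = width of the GPT-2 124M model
for _ in range(100):
    stats.update(rng.normal(scale=2.0, size=(64, 768)))  # true E[x^2] = 4
print(stats.running_phi.shape)   # (768,)
print(stats.running_phi.mean())  # close to 4.0
```

In this sketch each normalisation layer stores one value per channel; at inference time the stored estimate replaces the batch statistic.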

I think you're much better off following convention and using RMS normalisation or layer normalisation instead: they are much simpler, require less compute, and seem to get the job done, considering that the whole idea of layer normalisation was to reduce internal covariate shift to begin with.

PowerNorm's added complexity just feels unwarranted to me.
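For comparison, a minimal NumPy sketch of the two conventional alternatives mentioned above (forward pass only; gamma and beta stand in for the learned gain and bias and are plain arrays here). LayerNorm centres each token's features and divides by their standard deviation, while RMSNorm skips the centring and the bias and only divides by the root mean square, which saves the mean computation:

```python
import numpy as np


def layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """LayerNorm over the last axis: subtract the mean, divide by the standard deviation."""
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


def rms_norm(x: np.ndarray, gamma: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """RMSNorm over the last axis: no mean subtraction and no bias, only RMS scaling."""
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return gamma * x / rms


rng = np.random.default_rng(0)
C = 768  # width of the GPT-2 124M model
x = rng.normal(loc=1.0, size=(2, 4, C))  # (B, T, C) with a non-zero mean
gamma, beta = np.ones(C), np.zeros(C)

ln_out = layer_norm(x, gamma, beta)
rms_out = rms_norm(x, gamma)
print(np.abs(ln_out.mean(axis=-1)).max() < 1e-9)                  # True: LayerNorm removes each token's mean
print(np.allclose((rms_out ** 2).mean(axis=-1), 1.0, atol=1e-4))  # True: RMSNorm fixes the scale
print(np.abs(rms_out.mean(axis=-1)).min() > 0.1)                  # True: but it keeps the mean
```

Neither keeps running statistics: both normalise each token using only its own features, so training and inference behave identically.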


By the way, I almost forgot: to get the dataset (my dataset, or rather Andrej Karpathy's), you need to use this script. It tokenizes the whole dataset and writes it into shards that are ready to use:

https://huggingface.co/shng2025/GPT-Valkyrie_PN-124m/blob/denim-lake-75/fineweb.py