Previously with gradient accumulation, the loss was divided by the number of gradient accumulation steps, and this normalised loss was then logged. From a logging perspective, each train step corresponds to one gradient accumulation step, so the normalisation is not needed.
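A minimal sketch of the distinction, with made-up loss values: the 1/N scaling still applies to the value that feeds backward() (so the accumulated gradient matches a full batch), but the per-step logged value is the plain micro-batch loss. `ACCUM_STEPS` and `micro_losses` are illustrative stand-ins, not names from the codebase.

```python
ACCUM_STEPS = 4
micro_losses = [2.0, 1.0, 3.0, 2.0]  # stand-ins for per-micro-batch losses

logged, accumulated = [], 0.0
for loss in micro_losses:
    # the scaled value is what backward() would consume, so the summed
    # gradient over the window equals a full-batch gradient
    accumulated += loss / ACCUM_STEPS
    # but the value we log per train step is left unnormalised
    logged.append(loss)
```

Here `accumulated` ends up at the mean loss (2.0), which is what the optimizer effectively sees, while the log stream carries the raw per-step losses.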
Additionally, I refactored the gradient clipping in two ways. Firstly, we were not using FSDP's own gradient clipping. This is now integrated, but it still needs to be made configurable, which I will do in a follow-up PR.
Secondly, the gradient norm is now only computed and logged once per gradient accumulation window, rather than on every micro-step. Since FSDP's gradient clipping already syncs the gradient norm across all ranks, we no longer need to apply a reduce operation to it ourselves.
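The control flow can be sketched as below. `clip_fn` is a hypothetical stand-in for FSDP's clipping call, which returns a total gradient norm that is already synced across ranks, so the returned value is logged as-is with no extra all-reduce. `ACCUM_STEPS` and `train_loop` are illustrative names, not from the codebase.

```python
ACCUM_STEPS = 4

def train_loop(num_micro_steps, clip_fn):
    """Clip and log the grad norm only on micro-steps that complete
    an accumulation window, i.e. once per optimizer step."""
    logged_norms = []
    for micro_step in range(1, num_micro_steps + 1):
        # ... forward / backward on the micro-batch would go here ...
        if micro_step % ACCUM_STEPS == 0:
            grad_norm = clip_fn()           # clip once per optimizer step
            logged_norms.append(grad_norm)  # already rank-synced: log as-is
            # optimizer.step(); optimizer.zero_grad() would follow here
    return logged_norms
```

For example, 8 micro-steps with `ACCUM_STEPS = 4` yields exactly two clip/log events, one per optimizer step.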