ChrisDryden opened 4 months ago
To start off, I will implement the layernorm forward inside the backward-pass implementation and use the ln1 and ln2 values directly from that layernorm forward, giving an initial working version that recalculates these values in the backward pass.
In the above PR I was able to reduce the memory. With recompute set to 1, allocation went from this:

```
allocating 1439 MiB for activations
val loss 4.503491
allocating 237 MiB for parameter gradients
allocating 30 MiB for activation gradients
allocating 474 MiB for AdamW optimizer state m
allocating 474 MiB for AdamW optimizer state v
allocating 474 MiB for master copy of params
```
To this:

```
allocating 1307 MiB for activations
val loss 4.504488
allocating 237 MiB for parameter gradients
allocating 30 MiB for activation gradients
allocating 474 MiB for AdamW optimizer state m
allocating 474 MiB for AdamW optimizer state v
allocating 474 MiB for master copy of params
```
The PR was merged, but it still needs the second step: a simplified kernel that doesn't recompute everything and instead reuses the values already calculated in the forward pass.
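To sketch what that second step could look like: since the forward pass already saves the per-row mean and rstd, the normalized activations can be rebuilt with a single elementwise pass, skipping the two reduction loops entirely. This is a hypothetical illustration of the idea, not the planned kernel:

```c
#include <assert.h>
#include <math.h>

// Hypothetical "second step" sketch: rather than rerunning the full layernorm
// forward (two reductions over C per row), reuse the mean and rstd saved in
// the forward pass and rebuild the normalized output in one cheap pass.
static void layernorm_recompute_from_stats(float* out, const float* inp,
                                           const float* weight, const float* bias,
                                           const float* mean, const float* rstd,
                                           int N, int C) {
    for (int n = 0; n < N; n++) {
        float m = mean[n];  // saved in forward pass
        float s = rstd[n];  // saved in forward pass
        for (int c = 0; c < C; c++) {
            // no reductions needed; purely elementwise
            out[n * C + c] = ((inp[n * C + c] - m) * s) * weight[c] + bias[c];
        }
    }
}
```

Storing mean and rstd costs only 2 floats per row, while the full normalized activation tensor costs C floats per row, so this keeps most of the memory saving with far less recompute.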
@ngc92 did an analysis of the areas that take up the most memory and their impact on the number of batches that can be used, and found that one of the largest contributors was the memory associated with the layernorm activations.
This issue will track the implementation of recalculating the layernorm forward activations in the backward pass, similar to how the GELU is recalculated, so that we can reduce memory.