This is still in progress / experimental: it is currently implemented only for the standard Gemma MQA attention layers, and the backward pass is not yet parallelized. Because backpropagation needs the activations of every layer, the forward pass was also reimplemented with a new activation data structure that retains them all.
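A minimal sketch of what such an activation store could look like, assuming hypothetical names (`LayerActivations`, `ForwardActivations`) and buffers chosen for illustration rather than the actual layout used here; the key idea is that inference can overwrite a single scratch buffer, but the backward pass needs one set of buffers per layer kept alive for the whole step:

```c++
#include <cstddef>
#include <vector>

// Activations produced by one transformer layer during the forward pass.
// (Hypothetical names and fields, for illustration only.)
struct LayerActivations {
  std::vector<float> pre_attention_rms;  // layer input after RMSNorm, per token
  std::vector<float> q, k, v;            // attention projections (MQA: shared K/V head)
  std::vector<float> attention_out;      // attention output before the residual add
  std::vector<float> pre_ffw_rms;        // FFW input after RMSNorm
  std::vector<float> ffw_hidden;         // gated FFW hidden activations
  std::vector<float> layer_out;          // layer output after the residual add
};

// Activations for the whole model: one entry per layer, plus embeddings/logits,
// so the backward pass can revisit every layer without recomputation.
struct ForwardActivations {
  std::vector<float> embeddings;          // token embeddings (input to layer 0)
  std::vector<LayerActivations> layers;   // size == num_layers
  std::vector<float> final_norm;          // output of the final RMSNorm
  std::vector<float> logits;              // unnormalized output scores

  ForwardActivations(size_t num_layers, size_t num_tokens, size_t model_dim)
      : layers(num_layers) {
    embeddings.resize(num_tokens * model_dim);
    final_norm.resize(num_tokens * model_dim);
    // Per-layer buffers would be sized from the model config here.
  }
};
```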