foundation-model-stack / fms-acceleration

🚀 Collection of libraries used with fms-hf-tuning to accelerate fine-tuning and training of large models.

Allow Kernels for Full FT and Non-Quantized PEFT #79

Closed fabianlim closed 2 months ago

fabianlim commented 3 months ago

Description

This PR

  1. upgrades the framework to perform OR logic when activating plugins
  2. creates a FastKernelsAccelerationPlugin, an improved version of FastQuantizedPeftAccelerationPlugin
    • it can add kernels individually
    • it can be activated under a training stanza or a peft.quantized stanza (see the sketch after this list)
  3. adds FOAK support to the Full-Finetuning and Standard PEFT benchmarks
  4. adds FOAK support for 1 additional model
    • GPTBigCode
      • Note that due to GPTBigCode architecture limitations, only FastCrossEntropyLoss is supported in this PR. Additional support will be tracked in [placeholder issue]
  5. fixes a bug in ModelPatcher to address multiple reloads of the same target path
    • this affected the proper patching of FastCrossEntropyLoss
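For illustration, here is a minimal sketch of the OR activation logic from point 1. The function and the registration paths are hypothetical and may differ from the actual fms-acceleration framework API:

```python
# Minimal sketch of OR logic for plugin activation; names and config paths are
# hypothetical and may differ from the actual fms-acceleration framework code.
def plugin_activates(registration_keys: list[str], present_stanzas: set[str]) -> bool:
    # OR: the plugin activates if ANY of its registered config paths is present,
    # so a plugin registered under both a training stanza and a peft.quantized
    # stanza can activate when only one of them appears in the user config.
    return any(key in present_stanzas for key in registration_keys)

# Example: FastKernelsAccelerationPlugin registered under two stanzas, but the
# user config only contains the training stanza.
keys = [
    "training.fused_ops_and_kernels",           # hypothetical path
    "peft.quantization.fused_ops_and_kernels",  # hypothetical path
]
assert plugin_activates(keys, {"training.fused_ops_and_kernels"})
```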

Improvements to Full Finetuning

7–13.5% throughput improvement from the following kernels: FastCrossEntropyLoss, FastRMSNorm, FastRoPE.

1 GPU:

| Framework | Model | num gpus | batch size | throughput (toks/s) | Improvement % |
|---|---|---|---|---|---|
| fullFT | Mistral7B | 1 | 4 | 2910 | base |
| foak-fullFT | Mistral7B | 1 | 4 | 3218 | 10.5 |
| PEFT | Mistral7B | 1 | 4 | 3345 | base |
| foak-PEFT | Mistral7B | 1 | 4 | 3797 | 13.5 |

2 GPUs:

| Framework | Model | num gpus | batch size | throughput (toks/s) | Improvement % |
|---|---|---|---|---|---|
| fullFT | Mistral7B | 2 | 4 | 2886 | base |
| foak-fullFT | Mistral7B | 2 | 4 | 3093 | 7 |
| PEFT | Mistral7B | 2 | 4 | 3227 | base |
| foak-PEFT | Mistral7B | 2 | 4 | 3620 | 12 |
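For reference, the Improvement % column is the relative throughput gain over the matching base row; a quick check with the numbers from the tables above:

```python
# Improvement % = relative throughput gain over the matching base row.
# Numbers taken from the tables above (Mistral7B, batch size 4).
pairs = {
    "foak-fullFT (1 gpu)": (2910, 3218),
    "foak-PEFT   (1 gpu)": (3345, 3797),
    "foak-fullFT (2 gpu)": (2886, 3093),
    "foak-PEFT   (2 gpu)": (3227, 3620),
}
for name, (base, foak) in pairs.items():
    print(f"{name}: {(foak - base) / base * 100:.1f}%")  # ~10.6, 13.5, 7.2, 12.2
```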
Compatibility Matrix with Mixed Precision

| torch_dtype | Mixed Precision | Full-FT-FOAK | PEFT-FOAK | QPEFT-FOAK |
|---|---|---|---|---|
| FLOAT16 | - | ✗ Not Allowed | ✗ | ✗ |
| FLOAT16 | FP16 | ValueError: Attempting to unscale FP16 gradients. See here | Compatible | Compatible |
| BFLOAT16 | - | ✗ | ✗ | ✗ |
| BFLOAT16 | BF16 | Compatible | Compatible | Less Performant |
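For context, the FLOAT16 + FP16 error for Full-FT comes from PyTorch's GradScaler, which refuses to unscale gradients that are themselves fp16, i.e. the case where the full model weights are loaded in float16. A minimal reproduction, independent of this repo:

```python
# Minimal reproduction (no fms-acceleration code): full fp16 weights + fp16 mixed
# precision makes GradScaler raise "Attempting to unscale FP16 gradients."
import torch

model = torch.nn.Linear(8, 8, device="cuda", dtype=torch.float16)  # fp16 trainable weights
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

with torch.autocast("cuda", dtype=torch.float16):
    loss = model(torch.randn(4, 8, device="cuda", dtype=torch.float16)).sum()

scaler.scale(loss).backward()   # gradients come out in fp16
scaler.step(opt)                # ValueError: Attempting to unscale FP16 gradients.
```

The PEFT columns remain compatible presumably because the trainable adapter weights are kept in float32, so their gradients are not fp16.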

Regression Test for Loss, Memory, Throughput

We ran our alpaca benchmarks in bfloat16 for most experiments (except GPTQ-LoRA, which runs in float16; see issue) and see no significant regression in performance.

_Note: an outlier in the comparison plots shows an anomalous memory increase in a standard full-FT experiment on Mistral7B with no accelerations installed. Since it does not point to any issue with the code in this PR, it is likely due to slight instability of the benchmarking run._

Bug Fix to Model Patcher

The fix for the improper patching of FastCrossEntropyLoss does not significantly change FOAK performance; however, a slight decrease in improvement is observed for paddingfree+foak compared to the previous numbers (consistent with issue 70).
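For reference, the nature of the fix is to guard against reloading the same target path more than once, so that a later reload does not clobber an already-applied patch. A rough sketch of the idea (hypothetical names, not the actual ModelPatcher code):

```python
# Illustrative sketch only (hypothetical names): reload a patch target at most once,
# so a repeated trigger on the same path cannot undo an earlier patch such as
# FastCrossEntropyLoss.
import importlib

_already_reloaded: set[str] = set()

def reload_target_once(target_path: str) -> None:
    """Reload the module at `target_path` at most once per process."""
    if target_path in _already_reloaded:
        return  # already reloaded and patched; do not clobber the applied patch
    importlib.reload(importlib.import_module(target_path))
    _already_reloaded.add(target_path)
```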

FLAN (6000 samples) with PaddingFree

Before BugFix:

| Framework | Model | num gpus | batch size | train_runtime (s) | throughput (toks/s) | Improvement % |
|---|---|---|---|---|---|---|
| BNB + foak | Mistral7B | 2 | 4 | 1068 | 1328 | base |
| BNB + foak + paddingfree | Mistral7B | 2 | 4 | 605 | 2400 | +43 |
| GPTQ-LoRA + foak | Mistral7B | 2 | 4 | 1034 | 1372 | base |
| GPTQ-LoRA + foak + paddingfree | Mistral7B | 2 | 4 | 587 | 2472 | +43 |

With BugFix:

| Framework | Model | num gpus | batch size | train_runtime (s) | throughput (toks/s) | Improvement % |
|---|---|---|---|---|---|---|
| BNB + foak | Mistral7B | 2 | 4 | 1038 | 1368 | base |
| BNB + foak + paddingfree | Mistral7B | 2 | 4 | 674 | 2106 | +35 |
| GPTQ-LoRA + foak | Mistral7B | 2 | 4 | 1035 | 1372 | base |
| GPTQ-LoRA + foak + paddingfree | Mistral7B | 2 | 4 | 660 | 2160 | +36 |

Note: due to issues with FSDP-QLoRA in the latest transformers version (4.45.0dev) mentioned here, Granite with Fast Kernels will be addressed in a later PR.

TODO