foundation-model-stack / fms-fsdp

🚀 Efficiently (pre)training foundation models with native PyTorch features, including FSDP for training and the SDPA implementation of Flash Attention v2.
https://pytorch.org/docs/stable/fsdp.html
Apache License 2.0

[speculator training] Speculator training #35

Open daviswer opened 4 months ago

daviswer commented 4 months ago

Add support for speculator training, piggybacking off the existing training utilities.

Training script and speculator-specific utilities are inside the new speculator subfolder.

Uses the distributed setup, checkpointing, and dataloaders from this repo. Adds speculator-specific fields to the training config file (ignored during non-speculator training). It might make more sense to pull these new fields out into a separate config subclass under the speculator utilities (see the sketch below) - open to suggestions.
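As a rough, hypothetical sketch of that separate-subclass idea (the field names are illustrative and the import path of the existing train_config dataclass is an assumption, not quoted from this PR), it could look like:

```python
# Hypothetical sketch only -- field names are illustrative, and the import of
# the existing train_config dataclass is assumed, not taken from this PR.
from dataclasses import dataclass

from fms_fsdp.config import train_config


@dataclass
class speculator_train_config(train_config):
    # Speculator-only knobs; plain (non-speculator) training would ignore these.
    speculator_width: int = 4096      # hidden size of the speculator
    n_speculator_heads: int = 3       # number of lookahead prediction heads
    speculator_load_path: str = ""    # optional speculator checkpoint to resume from
```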

Uses speculator architecture from fms-extras.

Uses an altered Llama-7b and generate() function from base fms, allowing the speculator to access embedding vectors, not just logits/token predictions. Do not merge this until that dependency on the altered base fms is resolved.

daviswer commented 3 months ago

Plan is to move the include_embeds=True versions of Llama/GPTBigCode/generate() into fms-extras. Once that is done, I'll update the relevant imports here and then we can push this in.

daviswer commented 3 months ago

I've pulled all the include_embeds stuff out of fms into here. We now have EmbedLLaMA and EmbedGPTBigCode subclasses that override the corresponding forward function, and an altered version of generate for use only with this script. We register the subclassed models for use with get_model in the training script.
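Very roughly, the subclass-and-register pattern looks like the sketch below. This is not the code in this PR: the internal attribute names (base_model, head), the LLaMAConfig defaults, and the exact register_model signature are assumptions about the fms API, and the real EmbedLLaMA and altered generate live in train_speculator_utils.

```python
# Minimal sketch of the pattern, not the actual train_speculator_utils code.
# fms internals (base_model, head, register_model signature) are assumed here.
from fms import models
from fms.models.llama import LLaMA, LLaMAConfig


class EmbedLLaMA(LLaMA):
    """LLaMA whose forward can also return the final hidden states, so the
    speculator trains on embedding vectors rather than only logits."""

    def forward(self, x, *, include_embeds: bool = False, **kwargs):
        embeds = self.base_model(x, **kwargs)   # last-layer hidden states
        logits = self.head(embeds)               # project to the vocabulary
        return (logits, embeds) if include_embeds else logits


def _embed_llama_7b_factory(**kwargs):
    # Hypothetical factory; the real one would mirror fms's llama 7b configuration.
    return EmbedLLaMA(LLaMAConfig(), **kwargs)


# Register the subclass so the training script can build it via get_model:
#   model = models.get_model("embedllama", "7b", ...)
models.register_model("embedllama", "7b", _embed_llama_7b_factory)
```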

daviswer commented 3 months ago

Code is ready for review. The mypy errors are import errors: the check environment doesn't have fms-extras installed, and mypy can't resolve the local import of train_speculator_utils. Should I move the speculator subfolder under fms_fsdp so that I can use an absolute import path for it?
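For reference, the two import styles in question (a hedged illustration, assuming the folder layout described above):

```python
# Current layout (speculator/ at the repo root): the training script uses a
# local import that mypy, run from the repo root, cannot resolve:
#   from train_speculator_utils import EmbedLLaMA
#
# If the folder moved under the package, it becomes an absolute import that
# both mypy and installed users could resolve:
#   from fms_fsdp.speculator.train_speculator_utils import EmbedLLaMA
```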