daviswer opened 4 months ago
Plan is to move the `include_embeds=True` versions of Llama/GPTBigCode/`generate()` into fms-extras. Once that is done, I'll update the relevant imports here and then we can push this in.
I've pulled all the `include_embeds` stuff out of fms and into here. We now have `EmbedLLaMA` and `EmbedGPTBigCode` subclasses that override the corresponding forward function, and an altered version of `generate` for use only with this script. We register the subclassed models for use with `get_model` in the training script.
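The subclass-and-register pattern above can be sketched in miniature. This is a hypothetical, self-contained illustration: `BaseModel`, `EmbedModel`, `_registry`, and the toy arithmetic are stand-ins, not the actual fms classes or the real `get_model` registry.

```python
class BaseModel:
    """Stand-in for an fms base model: forward returns logits only."""

    def forward(self, tokens):
        embeds = [t * 0.5 for t in tokens]  # pretend embedding vectors
        logits = [e + 1.0 for e in embeds]  # pretend output head
        return logits


class EmbedModel(BaseModel):
    """Subclass whose forward can also surface the embedding vectors,
    mirroring the include_embeds=True behavior described above."""

    def forward(self, tokens, include_embeds=False):
        embeds = [t * 0.5 for t in tokens]
        logits = [e + 1.0 for e in embeds]
        if include_embeds:
            return logits, embeds
        return logits


# Minimal stand-in for a get_model-style registry: the subclassed model
# is registered under a name so the training script can request it.
_registry = {"embed_model": EmbedModel}


def get_model(name):
    return _registry[name]()


model = get_model("embed_model")
logits, embeds = model.forward([2.0, 4.0], include_embeds=True)
print(logits)  # [2.0, 3.0]
print(embeds)  # [1.0, 2.0]
```

The key point is that the base class is untouched; only the subclass knows about `include_embeds`, so non-speculator call sites keep working unchanged.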
Code is ready for review. The mypy errors are import errors: it doesn't like the local imports of fms-extras and `train_speculator_utils`. Should I move the `speculator` subfolder under `fms_fsdp` so that I can use an absolute path for that?
Add support for speculator training, piggybacking off the existing training utilities. The training script and speculator-specific utilities live inside the new `speculator` subfolder.

- Uses distributed setup, checkpointing, and dataloaders from this repo.
- Adds speculator-specific fields to the training config file (to be ignored during non-speculator training). It might make more sense to pull these new fields out into a separate config subclass under the speculator utilities; open to suggestions.
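One way the "ignored during non-speculator training" behavior could work is to give the new fields safe defaults on the shared config dataclass. A minimal sketch, assuming a dataclass-style config; the field names (`speculator_n_heads`, `speculator_emb_dim`) are illustrative, not the actual fms-fsdp names:

```python
from dataclasses import dataclass


@dataclass
class TrainConfig:
    # Existing training fields (illustrative values).
    learning_rate: float = 3e-4
    batch_size: int = 2
    # Speculator-specific fields, defaulted so that non-speculator
    # runs can simply ignore them.
    speculator_n_heads: int = 0  # 0 => no speculator training
    speculator_emb_dim: int = 4096


cfg = TrainConfig(batch_size=8)
print(cfg.speculator_n_heads)  # 0
```

The alternative mentioned above, a separate config subclass under the speculator utilities, would keep the base config clean at the cost of a second config type to thread through the training script.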
- Uses the speculator architecture from fms-extras.
- Uses the altered Llama-7b and `generate()` function from base fms, allowing the speculator to access embedding vectors, not just logits/token predictions. Do not merge this until that issue can be resolved.