john-hewitt opened this issue 1 year ago
I think this is because David changed this from the original. It looks like https://github.com/stanford-crfm/mistral/blob/bf9eff08e83f4d5703b69dfcb6c18e8e35a00a6d/src/corpora/auto.py#L94 does the right thing? So if you just build the conventional Hugging Face cache with `get_auto_dataset` (instead of David's custom index), it should work fine and create a standard tokenized Hugging Face cache with an extra validation set.
Basically, older Mistral just built a conventional Hugging Face cache; David created a new custom data handling setup and, I guess, didn't add in creating a validation set ... Conceptually the fix on the data side is just splitting a held-out slice off the train split, as in the sketch below.
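A minimal sketch with the standard `datasets` API, roughly what `get_auto_dataset` does internally (the `validation_ratio` value here is illustrative, not the repo's default):

```python
# Minimal sketch: openwebtext ships with only a "train" split, so carve a
# validation set out of it before tokenizing/caching.
from datasets import load_dataset

validation_ratio = 0.0005  # illustrative value, not the repo's default
raw = load_dataset("openwebtext")
split = raw["train"].train_test_split(test_size=validation_ratio, seed=42)
raw["train"] = split["train"]
raw["validation"] = split["test"]  # now lm_dataset["validation"] exists
```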
I've started this branch: https://github.com/stanford-crfm/mistral/tree/mistral-flash-dec-2022
This should have the Mistral Feb 2022 code plus some bug fixes, and it has worked with Flash Attention.
I'm in vacation mode, but I am happy to help you get this branch working ... you will need to install Flash Attention and a specially modified Hugging Face transformers as well ...
Some instructions on getting this working (remember to use the branch: https://github.com/stanford-crfm/mistral/tree/mistral-flash-dec-2022):
```
conda create -n mistral python=3.8.12 pytorch=1.11.0 torchdata cudatoolkit=11.3 -c pytorch
conda activate mistral
pip install -r setup/pip-requirements.txt
```
I think this will work with newer PyTorch, etc. ... but you need to make sure you build Flash Attention against whatever version you are using ...
Install transformers from Git (https://github.com/huggingface/transformers), then replace `src/transformers/models/gpt2/modeling_gpt2.py` with the version checked into this branch in the `transformers` dir at the top level of this repo.
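One quick way to confirm the patched file is actually the one being picked up (an illustrative check, not part of the repo):

```python
# Illustrative check: make sure Python imports the patched modeling file,
# not a stale copy from another transformers install.
import transformers.models.gpt2.modeling_gpt2 as gpt2_modeling
print(gpt2_modeling.__file__)  # should point into your patched checkout
```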
When creating an environment, make sure to install Flash Attention (https://github.com/HazyResearch/flash-attention) ... you may need to roll back to this commit: f515c77f2528b5062ebcc6c905c8817ca0ac0ad1 ... last time I tried to get this working, it failed because of issues with newer versions of Flash Attention; those may have been resolved in main by now, but I rolled back to that commit and it was fine ...
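Since mismatched builds are the usual failure mode here, it may be worth sanity-checking the PyTorch/CUDA combination and the import before launching a run (illustrative, not part of the repo):

```python
# Illustrative sanity check: flash-attention must be compiled against the
# same PyTorch/CUDA combination you actually run with.
import torch
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())

import flash_attn  # should import cleanly if the build succeeded
print(flash_attn.__file__)
```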
Please let me know if you run into any issues and we can clean up this branch + instructions ... but if all goes well, you should get super fast Flash Attention GPT-2 training, which is something like 2x faster ...
In the future we should think about reconciling this branch with current main ... but if you just want something working in the next day, this is the quickest route ...
Note: add a file called `hostfile` in the top level directory (even a blank one if just using one machine). Sample command:
```
deepspeed --hostfile hostfile --num_gpus 8 --num_nodes 1 --master_addr sphinx4 train.py --config conf/your_config.yaml --nnodes 1 --nproc_per_node 8 --training_arguments.per_device_train_batch_size 16 --training_arguments.deepspeed conf/deepspeed/z2-small-bf16-conf.json --run_id mistral-w-flash-demo
```
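If you do list machines in the hostfile, DeepSpeed's format is one line per host with a `slots` count; for a single 8-GPU machine it might look like this (hostname illustrative, taken from the sample command above):

```
sphinx4 slots=8
```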
You need to use bf16 ... a bad feature of this branch right now is that this is just hard-coded here: https://github.com/stanford-crfm/mistral/blob/3a7dfac4836e760b6a3afb880f8f79585c357281/src/args/training_args.py#L67
So it'd be a good idea to make this more transparent ... this branch is sort of my personal experimentation that I got running and could use some cleanup ...
Flash Attention requires bf16 or fp16 ... and you need bf16 for stability ...
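For context, one cleanup would be to surface this as a normal training argument rather than hard-coding it; `transformers`' `TrainingArguments` already exposes a `bf16` flag. A hedged sketch (the output path is illustrative):

```python
# Hedged sketch: surface bf16 as a regular training argument instead of
# hard-coding it in the args module.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="runs/mistral-w-flash-demo",  # illustrative path
    bf16=True,  # required by this branch; Flash Attention needs bf16/fp16
)
```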
@J38 Hello, I had the same issue of the code not working for openwebtext due to the missing validation set, so I tried your solution above. But I encountered the error `ImportError: cannot import name 'LMDataCollator' from 'src.core.trainer'`. It looks like the `src/core/trainer.py` file in the branch https://github.com/stanford-crfm/mistral/tree/mistral-flash-dec-2022 does not have a class called `LMDataCollator`. Could you please look into that?
Can you provide more details about what is causing that error (e.g. which line is failing in which file)? The branch is older code from before those changes were made, so it should not require `LMDataCollator`. Are you pre-training from scratch, or trying to fine-tune a model trained with main-branch code?
I guess it is this line: https://github.com/stanford-crfm/mistral/blob/3a7dfac4836e760b6a3afb880f8f79585c357281/train.py#L39
Yes, it is this line, and I believe `LMDataCollator` is also used at line 158 of the same file. But I was able to fix the dev set problem by adding a few lines of code on the main branch, so I think the issue is resolved.
I tried reverting `train.py` to the February 2022 version; does that help?
**Describe the bug**
After building the index for openwebtext, building the trainer fails (at line 161 of `train.py`) because no `validation` dataset is constructed. I believe this is because the `lm_dataset` object is built with Hugging Face's `load_dataset` on the `openwebtext` named dataset, and it has no validation split. The `validation_ratio` quinine config option is only used in building the `custom_eval_datasets`, not the `lm_dataset` object, so it is not used to portion out part of `openwebtext` as a validation set.

**To Reproduce**
Replace `datasets/wikitext2.yaml` with `datasets/openwebtext.yaml` in `mistral-micro.yaml` (and make other artefact location changes) and run.

**Expected behavior**
No failure occurs at line 161 of `train.py` when `lm_dataset['validation']` is accessed.