peregilk opened 6 months ago

Do you have any plans for adding support for Llama 3? Any idea how complex this would be, apart from new configs?
@Helw150 said it worked out of the box. Just configs I think
That's fantastic, @dlwh! It would be great if you could share your configs, @Helw150.
I must admit I have not dug into the details here yet, but my understanding is that the biggest architectural changes are a larger tokenizer and GQA in the smaller models. I haven't seen GQA used in any of the Levanter models, but I found a post saying it is supported. Can this also just be enabled through the configs?
I also read a post about them doing some masking on longer sequences so that the attention did not "spill over" to new documents.
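As I understand it, that refers to intra-document (block-diagonal) causal masking when multiple documents are packed into one sequence. A generic sketch of the idea (not Levanter's actual implementation):

```python
# Generic sketch: a causal mask that is also block-diagonal over packed
# documents, so tokens never attend across document boundaries.
import numpy as np

def intra_document_causal_mask(doc_ids: np.ndarray) -> np.ndarray:
    """Boolean [seq, seq] mask: token i may attend to token j only if j <= i
    and both tokens belong to the same packed document."""
    seq_len = doc_ids.shape[0]
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    same_doc = doc_ids[:, None] == doc_ids[None, :]
    return causal & same_doc

# Two documents packed into one sequence of length 6.
print(intra_document_causal_mask(np.array([0, 0, 0, 1, 1, 1])).astype(int))
```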
The model seems to start training with:
```yaml
data:
  tokenizer: "meta-llama/Meta-Llama-3-8B"
model:
  type: llama
initialize_from_hf: "meta-llama/Meta-Llama-3-8B"
use_hf_model_config: true
```
However, I keep getting the message: "The tokenizers appear to be different. You may want to check this."
Not really sure what is causing this.
@dlwh Unfortunately, I cannot seem to get it to work right out of the box. The model trains, but when training on a domain-specific corpus, the loss starts way too high and never fully recovers.
I am pretty sure the issue is the vocab size. I cannot seem to override the vocab size in the model config.
This line seems to return the default Llama tokenizer: https://github.com/stanford-crfm/levanter/blob/bd2aad66a2c301a52b52a65b455617aa2e452ba6/src/levanter/main/train_lm.py#L62
While it is overwritten later, I think this is the main issue.
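For reference, a minimal way to check what the tokenizer and config report, outside of Levanter (assuming transformers is installed and you have access to the gated repo):

```python
# Quick sanity check: the Llama 3 tokenizer and model config should both
# report a 128256-token vocab, so training against a 32000-sized embedding/head
# would explain a loss that starts far too high.
from transformers import AutoConfig, AutoTokenizer

repo = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(repo)
config = AutoConfig.from_pretrained(repo)
print(len(tokenizer), config.vocab_size)  # expect 128256 and 128256
```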
I have tried both reading the configs from HF, and creating them from scratch.
Please advise.
ok i'll try to take a look this weekend. Do you have a full config I can use as a reproducer by any chance?
Awesome. Here is the config I have been using. I just replaced the URLs.
```yaml
data:
  train_urls:
    - "gs://mydatabucket/train-shard-{0001..0147}-of-0147.json.gz"
  validation_urls:
    - "gs://mydatabucket/NCC_plus_scandi/validation-shard-0001-of-0001.json.gz"
  cache_dir: "gs://mycachebucket/tokenized/llama3hfconfigfalse/"
  tokenizer: "meta-llama/Meta-Llama-3-8B"
model:
  type: llama
  seq_len: 2048
  hidden_dim: 4096
  intermediate_dim: 14336
  num_layers: 32
  num_heads: 32
  num_kv_heads: 8
  initializer_range: 0.02
  use_flash_attention: true
initialize_from_hf: "meta-llama/Meta-Llama-3-8B"
use_hf_model_config: false
trainer:
  wandb:
    entity: "myentity"
    project: "myproject"
    tags: ["llama3"]
    name: north-llamatre-hfconfigfalse
  mp: p=f32,c=bfloat16
  train_batch_size: 256
  num_train_steps: 10000
  steps_per_eval: 250
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  checkpointer:
    base_path: "gs://mycheckpointbucket/north-llama3-hfconfigfalse/checkpoints"
    keep:
      - every: 1000
optimizer:
  learning_rate: 1.2e-5
  weight_decay: 0.1
  min_lr_ratio: 0.1
  warmup: 1000
hf_save_steps: 5000
hf_save_path: "gs://myhfbucket/north-llama3-hfconfigfalse/hf"
```
I have also tried setting `use_hf_model_config: true`. This gave the same result.
What I am seeing can be illustrated here:
The red line is the loss of a Mistral model. The grey line is from Llama 3. Apart from that, the settings are identical, and both use the HF tokenizer. The pattern is very similar to what we see when simply hot-swapping to a new tokenizer.
Do you have a reproduction of a case where the Levanter implementation gives you a different prediction than the HuggingFace implementation? As an example, here's a round trip test I used to verify the Whisper implementation https://github.com/stanford-crfm/levanter/blob/407d54b63e141d6261754986453bb1ffd1c8afb7/tests/whisper_test.py#L130
The only architectural change in Llama 3 is grouped-query attention, which is supported here: https://github.com/stanford-crfm/levanter/blob/407d54b63e141d6261754986453bb1ffd1c8afb7/src/levanter/models/llama.py#L236
I've exported a few Llama 3 finetunes from Levanter to HuggingFace successfully, and the models seem to work as expected for inference, so it's unclear to me whether the above case suggests a bug or is a function of the much larger vocab size of Llama 3 vs. Mistral. I'm not sure what the data mix is above, but if it's multilingual, it's also likely that Mistral starts from a lower loss because it's more explicitly designed for multilinguality.
If you send over a case where HuggingFace and Levanter output different logits for the Llama 3 weights, I'd be happy to take on the debugging from there!
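For concreteness, such a comparison could look roughly like this, using only the transformers API (the local export path is a placeholder, and loading two 8B models in fp32 on CPU needs a lot of RAM):

```python
# Rough sketch: compare logits from the original HF weights against a
# Levanter export saved in HF format.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(repo)
inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")

reference = AutoModelForCausalLM.from_pretrained(repo, torch_dtype=torch.float32)
exported = AutoModelForCausalLM.from_pretrained("./levanter-export", torch_dtype=torch.float32)

with torch.no_grad():
    ref_logits = reference(**inputs).logits
    exp_logits = exported(**inputs).logits

print(torch.max(torch.abs(ref_logits - exp_logits)))  # should be ~0 for an untrained round trip
```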
I am trying to debug this and test on downstream tasks by exporting to HF. However, I noticed that for llama3, no tokenizer.model file is created when saving to HF. Have you experienced this @Helw150?
Edit: I see the reason for this is that the HF repo does not contain a tokenizer.model file.
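A quick way to confirm this (a minimal sketch, assuming huggingface_hub is installed and you have access to the gated repo): Llama 3 switched to a BPE tokenizer.json, so there is no sentencepiece tokenizer.model like Llama 2 had.

```python
# List the tokenizer-related files in the Llama 3 repo; expect tokenizer.json
# and tokenizer_config.json, but no tokenizer.model.
from huggingface_hub import list_repo_files

files = list_repo_files("meta-llama/Meta-Llama-3-8B")
print(sorted(f for f in files if "token" in f))
```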
Reopening this. I have trained a bit more, and I am really not satisfied with the results, even though the train/eval loss looks fine.
Do you have a working llama3 config file, @Helw150? I want to double-check whether I have made any mistakes here.
Hi!
My use case is a bit non-standard (training multi-modal encoders), so I'm not sure my configs will help much. If you want to check them anyway, you can find them on the will/distill branch, tagged with via_*. In these cases, I'm leaving Llama frozen but still need to get gradients from it. I've done runs with both Llama 2 and Llama 3 and haven't seen any surprising-looking issues when switching to Llama 3!
Could you give a bit more details about the issue you are facing? Does it seem like the model isn't training properly? Or is it that the results aren't satisfactory?
If it's the latter, additional context (e.g. specific symptoms, expected behavior) would help me understand whether there's an underlying bug that could cause this or whether it's a matter of hyperparameters/underlying capabilities!
What revision/commit were you using to train? My usage of the TPU splash attention had/has a bug that messed everything up. I'm like 60% sure I know how to fix it (and you can probably fix your checkpoints post-hoc), but I need another day or so. If you want to try something, can you pre-multiply all of the q_proj weights by sqrt(head_dim)? I haven't verified that yet, but I strongly suspect that's the fix.
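For example, an untested sketch of that post-hoc fix applied to an HF-format export (paths are placeholders):

```python
# Untested: scale every q_proj weight of an affected checkpoint by sqrt(head_dim),
# then save a patched copy.
import math
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("./affected-checkpoint", torch_dtype=torch.float32)
head_dim = model.config.hidden_size // model.config.num_attention_heads  # 128 for Llama 3 8B

with torch.no_grad():
    for layer in model.model.layers:
        layer.self_attn.q_proj.weight.mul_(math.sqrt(head_dim))

model.save_pretrained("./patched-checkpoint")
```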
Ah yes, worth noting that I haven't pulled in the Splash Attention changes yet
splash attention is currently disabled so main is fine 🤞 right now
I was using splash attention, so that might have caused the error.
However, I suspected this was a tokenizer-size issue. I also remember getting a warning about non-matching tokenizers here.
But I can retry this without splash, and see if that is related.
I believe splash is now fixed in latest main, but it's now off by default.
Can you try `--model.attn_backend splash` and `--model.attn_backend jax_flash` and let me know if things seem ok?
Awesome! I have not been training for long, but in general my good runs have started with an eval loss of around 2.5, while the broken runs have started at 6. On the latest main, training starts with a 2.5 loss both with and without flash attention. Looks very good.
For reference (in case others are having the same issue), the correct flags are uppercase: `--model.attn_backend SPLASH` and `--model.attn_backend JAX_FLASH`.
Splash automatically upcasts to 32-bit, since 16-bit is not working. I understand this is expected.
Awesome thanks for your patience.
Yeah, for whatever reason they don't support bf16 for attention with that kernel yet
the uppercase thing can be fixed by upgrading draccus to >=0.8.0
@peregilk Llama 3 shouldn't work nicely out of the box, as it uses a different theta for the RoPE scaling, and configuring that isn't yet supported in Levanter. This issue should probably be re-opened. Even when I use the correct rope theta, I don't get reasonable results in Levanter (i.e. eval_lm gives me a loss of ~7 on neutral pretraining datasets like SlimPajama). @dlwh any ideas?
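For reference, the theta difference is visible directly in the HF configs (a minimal check, assuming transformers is installed and you have access to the gated repos):

```python
# Llama 3 raises the RoPE base frequency compared to Llama 2.
from transformers import AutoConfig

print(AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf").rope_theta)    # 10000.0
print(AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B").rope_theta)  # 500000.0
```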
that's not great. Probably need to spend some time in a debugger.
i probably won't get to this for at least a few days myself, but happy to provide some support
@dlwh any progress on this one? I was thinking of switching to Levanter from composer.
i don't really understand the issue. we have a unit test (which I recognize is not necessarily proof it's correct) and we support rope scaling now. Does someone have code that fails?
ok i see. this is becoming a priority for me so i will try to tackle it by wednesday
I haven't fully tested it, but can you try main? I added the new Llama 3 rope stuff.
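For anyone following along, the "llama3"-style RoPE frequency scaling (as used by the 3.1 checkpoints) looks roughly like the sketch below; this is the published formula, not Levanter's code, and the base Llama 3 8B only changes rope_theta without any scaling:

```python
# Rough numpy sketch of the "llama3" RoPE scaling: low frequencies are slowed
# down by `factor`, high frequencies are left alone, and the band in between
# is smoothly interpolated.
import numpy as np

def llama3_scale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                          high_freq_factor=4.0, old_context_len=8192):
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * np.pi / inv_freq

    # Slow down low frequencies (long wavelengths); keep high frequencies.
    scaled = np.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Smoothly interpolate the band in between.
    smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    medium = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    return np.where(medium, smoothed, scaled)

head_dim, rope_theta = 128, 500000.0
inv_freq = 1.0 / (rope_theta ** (np.arange(0, head_dim, 2) / head_dim))
print(llama3_scale_inv_freq(inv_freq)[:4])
```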
@mayankjobanputra did you have a chance to try it?
@dlwh I haven't tried it yet. Still preprocessing the data and meanwhile writing some infra code around the framework. If everything goes smoothly I should be able to answer your question in 15ish days.