eole-nlp / eole

Open language modeling toolkit based on PyTorch
https://eole-nlp.github.io/eole
MIT License

`eole train` returns pydantic validation error for `TrainConfig` #5

Closed: l-k-11235 closed this issue 3 months ago

l-k-11235 commented 3 months ago

I have this error:

[2024-06-06 09:20:01,521 INFO] Loading checkpoint from models/llama3/llama3-8b
[2024-06-06 09:20:01,541 WARNING] You have a CUDA device, should run with -gpu_ranks
[2024-06-06 09:20:01,582 INFO] Get special vocabs from Transforms: {'src': [], 'tgt': []}.
[2024-06-06 09:20:01,743 INFO] Transforms applied: ['filtertoolong', 'onmt_tokenize', 'insert_mask_before_placeholder']
[2024-06-06 09:20:01,764 INFO] The first 10 tokens of the vocabs are:['<unk>', '!', '"', '#', '$', '%', '&', "'", '(', ')']
[2024-06-06 09:20:01,764 INFO] The decoder start token is: <s>
[2024-06-06 09:20:01,768 INFO] Option: tgt_subword_model, value: models/llama3/llama3-8b/bpe.model, overriding model: None
[2024-06-06 09:20:01,768 INFO] Option: src_subword_type, value: bpe, overriding model: none
[2024-06-06 09:20:01,768 INFO] Option: src_subword_model, value: /models/llama3/llama3-8b/bpe.model, overriding model: None
[2024-06-06 09:20:01,768 INFO] Option: gpt2_pretok, value: True, overriding model: False
[2024-06-06 09:20:01,768 INFO] Option: tgt_subword_type, value: bpe, overriding model: none
[2024-06-06 09:20:01,768 INFO] Option: tgt_seq_length, value: 2048, overriding model: 512
[2024-06-06 09:20:01,768 INFO] Option: src_seq_length, value: 2048, overriding model: 512
[2024-06-06 09:20:01,768 INFO] Option: tensorboard, value: True, overriding model: False
[2024-06-06 09:20:01,768 INFO] Option: tgt_word_vec_size, value: 500, overriding model: 4096
[2024-06-06 09:20:01,768 INFO] Option: decoder_type, value: rnn, overriding model: transformer_lm
[2024-06-06 09:20:01,768 INFO] Option: save_data, value: ./finetune/llama3-8b-finetune, overriding model: None
[2024-06-06 09:20:01,768 INFO] Option: tensorboard_log_dir, value: ./finetune/llama3-8b-finetune/logs/, overriding model: runs/eole
[2024-06-06 09:20:01,768 INFO] Option: src_vocab, value: /models/llama3/llama3-8b/vocab.txt, overriding model: None
[2024-06-06 09:20:01,768 INFO] Option: transforms, value: ['insert_mask_before_placeholder', 'onmt_tokenize', 'filtertoolong'], overriding model: ['filtertoolong']
[2024-06-06 09:20:01,768 INFO] Option: param_init, value: 0.0, overriding model: 0.1
[2024-06-06 09:20:01,768 INFO] Option: num_workers, value: 1, overriding model: 2
[2024-06-06 09:20:01,768 INFO] Option: gpu_ranks, value: [0], overriding model: []
[2024-06-06 09:20:01,768 INFO] Option: warmup_steps, value: 100, overriding model: 4000
[2024-06-06 09:20:01,768 INFO] Option: valid_batch_size, value: 20482, overriding model: 256
[2024-06-06 09:20:01,768 INFO] Option: lora_dropout, value: 0.05, overriding model: 0.0
[2024-06-06 09:20:01,768 INFO] Option: save_checkpoint_steps, value: 500, overriding model: 5000
[2024-06-06 09:20:01,768 INFO] Option: param_init_glorot, value: True, overriding model: False
[2024-06-06 09:20:01,768 INFO] Option: max_grad_norm, value: 0.0, overriding model: 5.0
[2024-06-06 09:20:01,768 INFO] Option: model_path, value: ./finetune/llama3-8b-finetune, overriding model: model
[2024-06-06 09:20:01,768 INFO] Option: learning_rate, value: 0.0001, overriding model: 1.0
[2024-06-06 09:20:01,768 INFO] Option: batch_size, value: 2048, overriding model: 896
[2024-06-06 09:20:01,768 INFO] Option: attention_dropout, value: [0.0], overriding model: [0.1]
[2024-06-06 09:20:01,768 INFO] Option: keep_checkpoint, value: 20, overriding model: -1
[2024-06-06 09:20:01,768 INFO] Option: lora_layers, value: ['linear_values', 'linear_query', 'linear_keys', 'final_linear'], overriding model: []
[2024-06-06 09:20:01,768 INFO] Option: adam_beta2, value: 0.998, overriding model: 0.999
[2024-06-06 09:20:01,768 INFO] Option: quant_type, value: bnb_NF4, overriding model: 
[2024-06-06 09:20:01,768 INFO] Option: zero_out_prompt_loss, value: True, overriding model: False
[2024-06-06 09:20:01,768 INFO] Option: lora_alpha, value: 8, overriding model: 1
[2024-06-06 09:20:01,768 INFO] Option: quant_layers, value: ['w_1', 'w_2', 'w_3', 'linear_values', 'linear_query', 'linear_keys', 'final_linear'], overriding model: []
[2024-06-06 09:20:01,768 INFO] Option: accum_count, value: [8], overriding model: [32]
[2024-06-06 09:20:01,768 INFO] Option: dropout, value: [0.0], overriding model: [0.3]
[2024-06-06 09:20:01,768 INFO] Option: train_from, value: models/llama3/llama3-8b, overriding model: None
[2024-06-06 09:20:01,768 INFO] Option: bucket_size, value: 32768, overriding model: 262144
[2024-06-06 09:20:01,768 INFO] Option: seed, value: 1234, overriding model: -1
Traceback (most recent call last):
  File "/usr/local/bin/eole", line 33, in <module>
    sys.exit(load_entry_point('EOLE', 'console_scripts', 'eole')())
  File "/workdir/eole/eole/bin/main.py", line 39, in main
    bin_cls.run(args)
  File "/workdir/eole/eole/bin/run/train.py", line 68, in run
    train(config)
  File "/workdir/eole/eole/bin/run/train.py", line 55, in train
    train_process(config, device_id=0)
  File "/workdir/eole/eole/train_single.py", line 140, in main
    config = update_config_with_checkpoint(config, checkpoint=checkpoint)
  File "/workdir/eole/eole/train_single.py", line 129, in update_config_with_checkpoint
    config = TrainConfig(**updated_config)
  File "/usr/local/lib/python3.10/dist-packages/pydantic/main.py", line 176, in __init__
    self.__pydantic_validator__.validate_python(data, self_instance=self)
pydantic_core._pydantic_core.ValidationError: 1 validation error for TrainConfig
model.transformer_lm.encoder
  Input should be None [type=none_required, input_value={'src_word_vec_size': 500, 'encoder_type': 'rnn'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.7/v/none_required
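
For reference, the failure pattern can be reproduced in isolation with a toy pydantic schema. The class and field names below are made up for illustration and are not eole's actual config classes; they only mimic a decoder-only model config whose `encoder` field is required to be None, which is what rejects the encoder dict coming from the merged defaults.

```python
from pydantic import BaseModel, ValidationError

# Toy stand-ins, not eole's real classes: a decoder-only ("transformer_lm")
# model config whose encoder must remain None.
class DecoderOnlyModelConfig(BaseModel):
    decoder_type: str = "transformer_lm"
    encoder: None = None  # any non-None value triggers a none_required error

class ToyTrainConfig(BaseModel):
    model: DecoderOnlyModelConfig

try:
    ToyTrainConfig(
        model={
            "decoder_type": "transformer_lm",
            "encoder": {"src_word_vec_size": 500, "encoder_type": "rnn"},
        }
    )
except ValidationError as err:
    print(err)
    # 1 validation error for ToyTrainConfig
    # model.encoder
    #   Input should be None [type=none_required, ...]
```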

Here is the config I am using:

# General settings
seed: 1234
share_vocab: true
save_data: "./finetune/llama3-8b-finetune"
src_vocab: "${EOLE_MODEL_DIR}/llama3-8b/vocab.txt" # size
src_vocab_size: 128256
tgt_vocab_size: 128256

overwrite: true

report_every: 10

n_sample: 0

tensorboard: true
tensorboard_log_dir: ./finetune/llama3-8b-finetune/logs/

# transforms config
transforms: [insert_mask_before_placeholder, onmt_tokenize, filtertoolong]

transforms_configs:
    onmt_tokenize:
        src_subword_type: bpe
        src_subword_model: "${EOLE_MODEL_DIR}/llama3-8b/bpe.model"
        tgt_subword_type: bpe
        tgt_subword_model: "${EOLE_MODEL_DIR}/llama3-8b/bpe.model"
        gpt2_pretok: true
    filtertoolong:
        src_seq_length: 2048
        tgt_seq_length: 2048

# datasets
data:
    train_dataset:
        path_src: "/data/train.txt"
    valid:
        path_src: "/data/valid.txt"

skip_empty_level: silent # silently ignore empty lines in the data

training:
    # GPU dispatching
    world_size: 1
    gpu_ranks: [0]

    zero_out_prompt_loss: true

    dropout_steps: [0]
    dropout: [0.0]
    attention_dropout: [0.0]
    # Batching
    bucket_size: 32768
    num_workers: 1
    batch_type: "tokens"
    batch_size: 2048
    valid_batch_size: 20482
    batch_size_multiple: 1

    # Optimization
    model_dtype: "fp16"
    apex_opt_level: ""
    optim: "fusedadam"
    learning_rate: 0.0001
    warmup_steps: 100
    decay_method: "none"
    #learning_rate_decay: 0.98
    #start_decay_steps: 100
    #decay_steps: 10
    adam_beta2: 0.998
    accum_count: [8]
    accum_steps: [0]
    max_grad_norm: 0
    label_smoothing: 0.0
    param_init: 0
    param_init_glorot: true
    normalization: "tokens"

    # folders
    train_from: "${EOLE_MODEL_DIR}/llama3-8b"
    model_path: "./finetune/llama3-8b-finetune"
    keep_checkpoint: 20
    save_checkpoint_steps: 500

    # 4/8bit
    quant_layers: ['w_1', 'w_2', 'w_3', 'linear_values', 'linear_query', 'linear_keys', 'final_linear']
    quant_type: "bnb_NF4"

    # LoRa
    lora_layers: ['linear_values', 'linear_query', 'linear_keys', 'final_linear']
    lora_rank: 2
    lora_dropout: 0.05
    lora_alpha: 8
    lora_embedding: false
francoishernandez commented 3 months ago

There is a mismatch between the default config initialization and the update from the checkpoint. As a short-term patch, you can add `new_config.pop("model")` around here: https://github.com/eole-nlp/eole/blob/6d1e1be43a399a053db7b0ef7a6da58065091b78/eole/train_single.py#L117
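
For clarity, here is a rough sketch of what that patch does at the dict level. The contents of `new_config` below are invented for illustration; in eole the merged options are built inside `update_config_with_checkpoint`, as the traceback shows.

```python
# Invented merged-options dict; only the shape matters for the illustration.
new_config = {
    "training": {"learning_rate": 0.0001, "warmup_steps": 100},
    "model": {
        "decoder_type": "transformer_lm",
        "encoder": {"src_word_vec_size": 500, "encoder_type": "rnn"},
    },
}

# Short-term patch: drop the "model" section before TrainConfig(**new_config)
# is validated, so the stale default encoder fields never reach pydantic.
new_config.pop("model")
```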

I'll try to push a better fix later.