syncdoth / RetNet

Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.
MIT License

Can't Resume Training from Checkpoint #17

Closed · infosechoudini closed this issue 11 months ago

infosechoudini commented 12 months ago

Hey,

So I'm training the model using the Huggingface Trainer. If the trainer exits for any reason and I resume from a checkpoint, the model no longer learns. I'm using train.py as is and launching training with Accelerate and DeepSpeed.

train.py (ignore the extra imports; I was trying to load the model in 8-bit)

from dataclasses import dataclass

from transformers import (Trainer, TrainingArguments, AutoTokenizer, HfArgumentParser,
                          DataCollatorForLanguageModeling, BitsAndBytesConfig, AutoModel)
from datasets import load_dataset

from retnet.modeling_retnet import RetNetModelWithLMHead
from retnet.configuration_retnet import load_config_from_yaml
import torch
from retnet.bits_and_bytes import get_keys_to_not_convert, replace_with_bnb_linear
from accelerate import infer_auto_device_map
from accelerate.utils import (
    check_tied_parameters_on_same_device,
    find_tied_parameters,
    get_balanced_memory,
    get_max_memory,
    load_offloaded_weights,
    offload_weight,
    save_offload_index,
    set_module_tensor_to_device,
)
from transformers.utils import logging
from retnet.training_data.data_utils import *

logger = logging.get_logger(__name__)

#logging.set_verbosity_info()

@dataclass
class MyArgs:
    model_size: str = '2.7b'
    dataset_name: str = 'sst2'
    text_col: str = 'sentence'
    max_length: int = 512
    checkpoint: bool = False

def init_base_model(args):
    config = load_config_from_yaml(f"configs/retnet-{args.model_size}.yml")

    model = RetNetModelWithLMHead(config)

    return model

def get_base_model():    

    model = RetNetModelWithLMHead.from_pretrained('./checkpoints')

    return model

def main():
    parser = HfArgumentParser((TrainingArguments, MyArgs))
    train_args, args = parser.parse_args_into_dataclasses()

    tokenizer = AutoTokenizer.from_pretrained('./checkpoints', use_fast=True)

    dataset_split = get_pretrain_arvix(128, tokenizer, 0, 100)

    model = init_base_model(args)

    trainer = Trainer(model=model,
                      args=train_args,
                      train_dataset=dataset_split['train'],
                      eval_dataset=dataset_split['test'],
                      tokenizer=tokenizer,
                      data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False))

    if train_args.do_train:
        with torch.autocast('cuda'):
            trainer.train(resume_from_checkpoint=args.checkpoint)
            trainer.save_model(output_dir='./frenos_base')
    if train_args.do_eval:
        trainer.evaluate()

if __name__ == "__main__":
    main()
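
For clarity on how the resume path gets picked here: --checkpoint is a boolean flag, so trainer.train(resume_from_checkpoint=True) tells the Trainer to resume from the newest checkpoint-* directory inside output_dir. A quick way to confirm which checkpoint that is (illustrative snippet, not part of train.py):

from transformers.trainer_utils import get_last_checkpoint

# output_dir is "checkpoints" in train_renet.sh below
last_ckpt = get_last_checkpoint("checkpoints")
print("Trainer would resume from:", last_ckpt)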

ds_config.json

{
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": "auto",
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
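
For reference, the "auto" values above are filled in by the Huggingface Trainer's DeepSpeed integration from the TrainingArguments. With the flags in train_renet.sh below, the optimizer and scheduler blocks should resolve to roughly the following (illustrative only; total_num_steps is computed by the Trainer at runtime from the dataset size and num_train_epochs):

{
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 6e-4, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.05}
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {"warmup_min_lr": 0, "warmup_max_lr": 6e-4, "warmup_num_steps": 375}
    }
}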

train_renet.sh

# Note: weight decay is set to 0.05 to match the paper.
deepspeed --num_gpus 7 train.py \
    --model_size 300m \
    --output_dir checkpoints \
    --do_train --checkpoint \
    --warmup_steps 375 \
    --lr_scheduler_type linear \
    --prediction_loss_only \
    --remove_unused_columns False \
    --learning_rate 6e-4 \
    --weight_decay 0.05 \
    --num_train_epochs 4 \
    --logging_steps 10 \
    --eval_steps 1000 \
    --save_steps 1000 \
    --per_device_train_batch_size 12 \
    --per_device_eval_batch_size 12 \
    --deepspeed ../../ds_config.json 
infosechoudini commented 12 months ago

Also, I'm pretraining using the datasets from https://huggingface.co/datasets/togethercomputer/Long-Data-Collections.