learning-at-home / lean_transformer

Memory-efficient transformer. Work in progress.
MIT License

Research diary: reversibles #11

Open justheuristic opened 2 years ago

justheuristic commented 2 years ago

This issue may or may not contain my notes about implementing and training reversible models. From correspondence with @TimDettmers, @yhn112, @borzunov, @mryab.

Why? Reversible models are one of the few ways to fit large transformers into low-end DL rigs (a single GPU, 8-12 GB VRAM, 16-32 GB RAM). The alternatives are nested checkpointing [slower], checkpoint quantization [may affect convergence, still uses more memory], or checkpoint offloading [RAM is already used up by master params and optimizer state]. Hence, reversibles are the most memory-efficient training strategy, provided they can match the quality of regular models.
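To make the memory argument concrete, here is a minimal sketch of the Reformer-style reversible residual pair that the rest of this issue builds on (illustrative only; `ReversibleBlock`, `f` and `g` are placeholder names, not the lean_transformer API). Because each block's inputs can be recomputed from its outputs, inner activations do not need to be stored for the backward pass:

```python
import torch
from torch import nn


class ReversibleBlock(nn.Module):
    """Additive coupling: y1 = x1 + F(x2), y2 = x2 + G(y1)."""

    def __init__(self, f: nn.Module, g: nn.Module):
        super().__init__()
        self.f, self.g = f, g  # e.g. attention and feed-forward sub-layers

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return y1, y2

    def inverse(self, y1: torch.Tensor, y2: torch.Tensor):
        # reconstruct the inputs from the outputs, so the forward pass
        # only ever needs to keep the outputs of the topmost block
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return x1, x2


dim = 16
block = ReversibleBlock(nn.Linear(dim, dim), nn.Linear(dim, dim))
with torch.no_grad():
    x1, x2 = torch.randn(2, dim), torch.randn(2, dim)
    y1, y2 = block(x1, x2)
    r1, r2 = block.inverse(y1, y2)
    print(torch.allclose(r1, x1, atol=1e-5), torch.allclose(r2, x2, atol=1e-5))  # True True
```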

As of c076ba7, we support reversible=True, which enables the reversible transformer as defined in Reformer. However, this is not the only possible design. Existing alternatives are:

source: running-average coupling

```python
import torch


class MeanCoupling:
    """Couple the two reversible streams with a running average instead of a plain sum."""

    def __init__(self, layer_index: int):
        assert layer_index > 0, "layer with index zero must be applied before reversible (e.g. embeddings)"
        self.layer_index = layer_index

    def forward(self, other_stream: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
        # output is the running mean of the incoming stream and the new sub-layer output
        i = self.layer_index
        return other_stream * (i / (i + 1)) + fn_out * (1 / (i + 1))

    def inverse(self, forward_output: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
        # reconstruct the incoming stream from the output and the recomputed sub-layer output
        i = self.layer_index
        return (forward_output - fn_out * (1 / (i + 1))) / (i / (i + 1))
```
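A quick sanity check of the coupling above (assuming the `MeanCoupling` class from the snippet; `fn_out` stands for the recomputed sub-layer output at this position):

```python
import torch

coupling = MeanCoupling(layer_index=3)
other_stream, fn_out = torch.randn(2, 8), torch.randn(2, 8)

out = coupling.forward(other_stream, fn_out)               # running mean of stream and sub-layer output
recovered = coupling.inverse(out, fn_out)                  # invert it during the backward pass
print(torch.allclose(recovered, other_stream, atol=1e-6))  # True, up to float rounding
```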

Furthermore, it is unclear how best to use a reversible block's two inputs and two outputs with transformers. If we cannot afford to double the layer sizes, this raises the following questions:

Finally, there are implementation hacks that can affect the training throughput and memory requirements:

justheuristic commented 2 years ago

Research setting:

Limitations:

Results:

tl;dr: Reformer is still the best.

[image]

Validation loss closely tracks the training loss.

[image]

The gradient norm supports the numerical-instability hypothesis for momentum-reversible models.

[image]
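For intuition on that hypothesis: in a MomentumNet-style coupling (an assumption on my part, not necessarily the exact variant trained here), the inverse divides the velocity stream by the momentum coefficient, so rounding errors grow roughly by 1/beta per layer while activations are reconstructed. A minimal float32 sketch:

```python
import torch

torch.manual_seed(0)
depth, dim = 24, 256
layers = [torch.nn.Linear(dim, dim) for _ in range(depth)]


@torch.no_grad()
def reconstruction_error(beta: float) -> float:
    x, v = torch.randn(4, dim), torch.zeros(4, dim)
    x0 = x.clone()
    # forward: keep only the final (x, v), as a reversible net would
    for f in layers:
        v = beta * v + (1 - beta) * f(x)
        x = x + v
    # inverse: reconstruct the inputs layer by layer, dividing by beta each time
    for f in reversed(layers):
        x = x - v
        v = (v - (1 - beta) * f(x)) / beta
    return (x - x0).abs().max().item()


for beta in (0.99, 0.9, 0.5):
    print(f"beta={beta}: max reconstruction error {reconstruction_error(beta):.2e}")
```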

Details

Model config

```json
{
  "model_type": "lean_gpt",
  "architectures": ["LeanGPTModel"],
  "num_hidden_layers": 24,
  "num_hidden_groups": 24,
  "num_inner_groups": 1,
  "hidden_size": 1024,
  "embedding_size": 1024,
  "intermediate_size": 4096,
  "num_attention_heads": 16,
  "vocab_size": 50308,
  "hidden_act": "gelu_fused",
  "position_embedding_type": "rotary",
  "tie_word_embeddings": false,
  "reversible": YOUR_OPTION_HERE,
  "hidden_dropout_prob": 0,
  "attention_probs_dropout_prob": 0,
  "layer_norm_eps": 1e-12,
  "pad_token_id": 1,
  "bos_token_id": 0,
  "eos_token_id": 2
}
```
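For scale, a back-of-the-envelope parameter count implied by this config (ignores biases, layer norms and rotary embeddings, and assumes a standard two-matrix FFN):

```python
hidden, inter, n_layers, vocab = 1024, 4096, 24, 50308

attention = 4 * hidden * hidden     # Q, K, V and output projections
ffn = 2 * hidden * inter            # up- and down-projections
embeddings = 2 * vocab * hidden     # tie_word_embeddings = false, so input + output

total = n_layers * (attention + ffn) + embeddings
print(f"~{total / 1e6:.0f}M parameters")  # ≈ 405M
```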
Training code

Using a slightly modified fairseq training setup: https://github.com/justheuristic/junk/tree/fairseq (commit 9f5bff306ee93780e0c9483162dd24e244403919). The only difference is that it supports training with LeanTransformer; it was validated to match the learning curve of the stock fairseq transformer.

```bash
PYTHONPATH=`pwd`:$PYTHONPATH python fairseq_cli/train.py \
  $INPUT_PATH/data-bin/openwebtext --task language_modeling --arch lean_lm --hf-model-config $SOURCE_CODE_PATH/model_config.json \
  --max-tokens 32768 --update-freq 4 --max-update 50000 --tokens-per-sample 2048 --sample-break-mode none \
  --ddp-backend pytorch_ddp --distributed-world-size $NUM_GPUS --seed 4 \
  --amp --fp16-no-flatten-grads --min-loss-scale 1e-10 --fp16-scale-window 250 \
  --lr-scheduler cosine --lr 0.0003 --warmup-init-lr 0.0 --warmup-updates 5000 \
  --optimizer adam --weight-decay 0.1 --clip-norm 1.0 --adam-betas "(0.9, 0.95)" --adam-eps 1e-08 \
  --save-dir $SNAPSHOT_PATH --save-interval-updates 1000 --keep-best-checkpoints 1 --no-epoch-checkpoints --keep-interval-updates 2 \
  --valid-subset valid,valid_1b,valid_lambada,valid_ccnews,valid_wiki,valid_wiki2,valid_ptb --validate-interval-updates 1000 \
  --log-format simple --log-interval 50 --wandb-project $WANDB_PROJECT
```
Libraries & versions

```bash
#!/usr/bin/env bash
set -euxo pipefail

############################################################################
# core libraries
############################################################################

apt-get update --allow-unauthenticated --allow-insecure-repositories
apt-get install -y --no-install-recommends \
    build-essential \
    g++ gdb subversion \
    software-properties-common
apt-get install -y --no-install-recommends \
    wget curl vim nano ssh git libssl-dev

apt-get remove -y swig || true
apt-get install -y --no-install-recommends libstdc++6
apt-get install -y --no-install-recommends swig3.0
ln -s /usr/bin/swig3.0 /usr/bin/swig

############################################################################
# install anaconda (because native python stopped working)
############################################################################

wget https://repo.anaconda.com/archive/Anaconda3-2021.11-Linux-x86_64.sh
bash Anaconda3-2021.11-Linux-x86_64.sh -b -p /anaconda3
source /anaconda3/bin/activate

############################################################################
# common python libraries (project-specific libs are installed later)
############################################################################

conda update -y conda
conda install -y python=3.8.12 --strict-channel-priority
conda install -y numpy scipy cython pandas h5py numba
pip install --upgrade setuptools

# common + devops
pip install \
    PyYAML==5.4.1 \
    Pillow==8.3.0 \
    docopt==0.6.2 \
    typer==0.3.2 \
    black==21.6b0 \
    bokeh==2.4.0dev1 \
    isort==5.9.1 \
    icecream==2.1.1 \
    flake8==3.9.2 \
    uvloop==0.15.2 \
    packaging==19.0 \
    msgpack==0.5.6 \
    sortedcontainers==2.4.0 \
    configargparse==1.2.3 \
    tqdm==4.48.2 \
    termcolor==1.0.0

# common data science libs
pip install \
    ninja==1.10.0.post1 \
    tensorboardX==2.4 \
    wandb==0.10.33 \
    matplotlib==3.4.2 \
    seaborn==0.11.1 \
    holoviews==1.14.4 \
    plotly==5.1.0 \
    jupyterlab==3.0.16

# pytorch utils
conda install -y cudatoolkit=11.3 -c pytorch
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install https://github.com/huggingface/transformers/archive/3dc82427166239e2764196c07fa4c5dcc25b1590.zip  # 4.18.dev0
pip install datasets==2.0.0
pip install \
    torch_optimizer==0.1.0 \
    revlib==1.7.0 \
    bitsandbytes-cuda113==0.26.0 \
    pytorch-lightning==1.3.8 \
    triton==1.0.0 \
    einops==0.3.2 \
    libzero==0.0.5

# domain-specific ML libs
pip install \
    opencv-python==4.4.0.42 \
    albumentations==1.0.0 \
    scikit-image==0.17.2 \
    lmdb==1.2.1 \
    librosa==0.7.0 \
    sentencepiece==0.1.96 \
    nltk==3.6.2 \
    gensim==4.0.1 \
    sacrebleu==1.5.1 \
    sacremoses==0.0.45 \
    subword-nmt==0.3.7 \
    youtokentome==1.0.6

pip uninstall -y enum34

############################################################################
# Set locale
############################################################################

locale-gen ru_RU.UTF-8
update-locale

############################################################################
# Clean
############################################################################

apt-get autoremove
apt-get clean
apt-get autoclean
rm -rf /var/lib/apt/lists/*
rm -rf /tmp/*
rm -rf /.cache
rm -rf /var/cache/apt/*.bin
find /var/log -iname '*.gz' -delete
find /var/log -iname '*.1' -delete

###########################################################################
# project-specific libraries (aka YOUR CODE HERE)
###########################################################################

# hivemind dependencies
pip install \
    "prefetch_generator>=1.0.1" \
    "grpcio>=1.33.2" \
    "grpcio-tools>=1.33.2" \
    "multiaddr>=0.0.9" \
    "pymultihash>=0.8.2" \
    "cryptography>=3.4.6" \
    "pydantic>=1.8.1" \
    whatsmyip

pip install razdel==0.5.0

# golang
wget https://golang.org/dl/go1.16.4.linux-amd64.tar.gz
rm -rf /usr/local/go && tar -C /usr/local -xzf go1.16.4.linux-amd64.tar.gz
export PATH=$PATH:/usr/local/go/bin

pip install omegaconf==2.0.5 antlr4-python3-runtime==4.8 hydra-core==1.0.7
```