huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Training does not start #226

Closed: agademic closed this issue 1 year ago

agademic commented 1 year ago

Hi all,

first of all, thank you for the awesome library and your work!

I was trying to replicate your SFT example from the script clm_finetune_peft_imdb.py, but it seems that training gets stuck before it even starts. No errors, nothing. It's just stuck here:

0%| | 0/2748 [00:00<?, ?it/s]

followed by the tokenizer notice:

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the '__call__' method is faster than using a method to encode the text followed by a call to the 'pad' method to get a padded encoding.

This is the example with opt-125m, but the same happens with gpt-neox-20b. I'm running the script on an A100.

Here are the library versions:

- transformers==4.27.1 (also checked installing from source)
- peft==0.3.0.dev0
- bitsandbytes==0.37.1
- accelerate==0.17.0
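For reference, here's a minimal sketch (my addition, using only importlib.metadata from the standard library) to double-check which versions the environment running the script actually picks up; the package list is just the ones relevant here:

```python
# Print the versions of the relevant packages as seen by the interpreter that
# actually runs the training script.
from importlib.metadata import PackageNotFoundError, version

for pkg in ["transformers", "peft", "bitsandbytes", "accelerate", "trl", "torch"]:
    try:
        print(f"{pkg}=={version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
```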

Anything else I need to check?

Any hints are highly appreciated!

younesbelkada commented 1 year ago

Hello @agademic, how are you running the script? Are you running it on a single GPU?

agademic commented 1 year ago

Hi @younesbelkada,

yes, a single A100 40GB GPU, and I'm running it with the accelerate launch command.

younesbelkada commented 1 year ago

Thanks, are you running the exact same script? Can you share the script you are using by any chance?

agademic commented 1 year ago

Here's the script. I changed a line according to #219 and passed the output_dir directly in the script.

```python
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

import torch
import transformers
from datasets import load_dataset
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model, prepare_model_for_int8_training
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default="facebook/opt-125m",
        metadata={
            "help": (
                "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
            )
        },
    )

@dataclass
class DataTrainingArguments:
    dataset_name: Optional[str] = field(
        default="imdb", metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    block_size: Optional[int] = field(
        default=1024, metadata={"help": "Block size used to group the tokenized texts into fixed-length chunks."}
    )

training_args = TrainingArguments(output_dir="test")

parser = HfArgumentParser((ModelArguments, DataTrainingArguments))

model_args, data_args = parser.parse_args_into_dataclasses()

model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    load_in_8bit=True,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# ### Prepare model for training
#
# Some pre-processing needs to be done before training such an int8 model using `peft`, so let's import a utility function `prepare_model_for_int8_training` that will:
# - Cast the layer norms to `float32` for stability purposes
# - Add a `forward_hook` to the input embedding layer to enable gradient computation of the input hidden states
# - Enable gradient checkpointing for more memory-efficient training
# - Cast the output logits to `float32` for smoother sampling

if "gpt-neox" in model_args.model_name_or_path:
    model = prepare_model_for_int8_training(
        model, layer_norm_names=[]
    )
else:
    model = prepare_model_for_int8_training(model, layer_norm_names=[])

# ### Apply LoRA
#
# Here comes the magic with `peft`! Let's load a `PeftModel` and specify that we are going to use low-rank adapters (LoRA) using `get_peft_model` utility function from `peft`.
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

target_modules = None
if "gpt-neox" in model_args.model_name_or_path:
    target_modules = ["query_key_value", "xxx"]  # workaround to use 8bit training on this model
config = LoraConfig(
    r=16, lora_alpha=32, target_modules=target_modules, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
print_trainable_parameters(model)

block_size = data_args.block_size

def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; we could add padding instead if the model supported it.
    # You can customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# ### Training
data = load_dataset("imdb")
columns = data["train"].features
data = data.map(lambda samples: tokenizer(samples["text"]), batched=True, remove_columns=columns)
data = data.map(group_texts, batched=True)

model.gradient_checkpointing_enable()
trainer = transformers.Trainer(
    model=model,
    train_dataset=data["train"],
    args=training_args,
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()
```

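As a side note, one way to narrow down where the hang occurs (a hypothetical sanity check, not part of the original example) is to pull a single batch through the same data collator outside of the Trainer; if this already hangs, the problem is in data loading rather than in the first forward/backward pass:

```python
# Hypothetical sanity check: assumes the `data`, `tokenizer` and `training_args`
# objects defined in the script above. Fetching one batch manually shows whether
# the hang comes from data loading/collation or from the first training step.
import transformers
from torch.utils.data import DataLoader

loader = DataLoader(
    data["train"],
    batch_size=training_args.per_device_train_batch_size,
    collate_fn=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
batch = next(iter(loader))
print({k: tuple(v.shape) for k, v in batch.items()})
```
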
younesbelkada commented 1 year ago

Weird, it works on my T4. Here is the output of pip freeze:

absl-py==1.4.0
accelerate==0.17.1
aiofiles==23.1.0
aiohttp==3.8.3
aiosignal==1.3.1
alembic==1.9.2
altair==4.2.2
anyio==3.6.2
anykeystore==0.2
appdirs==1.4.4
APScheduler==3.9.1.post1
arrow==1.2.3
asttokens==2.2.1
astunparse==1.6.3
async-timeout==4.0.2
attrs==22.2.0
audioread==3.0.0
Babel==2.11.0
backcall==0.2.0
backoff==1.11.1
beautifulsoup4==4.11.1
binaryornot==0.4.4
bitsandbytes==0.37.1
black==23.1.0
bleach==6.0.0
boto3==1.26.59
botocore==1.29.59
braceexpand==0.1.7
Brotli==1.0.9
cachetools==5.3.0
certifi @ file:///croot/certifi_1671487769961/work/certifi
cffi==1.15.1
chardet==5.1.0
charset-normalizer==2.1.1
chex==0.1.5
click==8.1.3
clldutils==3.19.0
cmaes==0.9.1
cmake==3.26.0
codecarbon==1.2.0
colorama==0.4.6
coloredlogs==15.0.1
colorlog==6.7.0
commonmark==0.9.1
contourpy==1.0.7
cookiecutter==1.7.3
cryptacular==1.6.2
cryptography==39.0.0
csvw==3.1.3
cycler==0.11.0
dash==2.8.0
dash-bootstrap-components==1.3.1
dash-core-components==2.0.0
dash-html-components==2.0.0
dash-table==5.0.0
datasets==2.9.0
decorator==5.1.1
decord==0.6.0
defusedxml==0.7.1
diffusers==0.12.1
dill==0.3.4
distlib==0.3.6
dlinfo==1.2.1
dm-tree==0.1.8
docker==4.4.4
docker-pycreds==0.4.0
docutils==0.19
entrypoints==0.4
evaluate==0.4.0
exceptiongroup==1.1.0
execnet==1.9.0
executing==1.2.0
faiss-cpu==1.7.3
fastapi==0.92.0
fastjsonschema==2.16.2
ffmpy==0.3.0
filelock==3.9.0
fire==0.5.0
flake8==6.0.0
Flask==2.2.2
flatbuffers==2.0.7
flax==0.5.3
fonttools==4.38.0
frozenlist==1.3.3
fsspec==2023.1.0
ftfy==6.1.1
fugashi==1.2.1
gast==0.4.0
gdown==4.6.0
gitdb==4.0.10
GitPython==3.1.18
google-auth==2.16.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
gql==3.4.0
gradio==3.19.1
graphql-core==3.2.3
greenlet==2.0.2
grpcio==1.51.1
h11==0.14.0
h5py==3.8.0
hf-doc-builder @ git+https://github.com/huggingface/doc-builder@aa0b37d9611d89f95108fb9a1c46e83c813ab4ce
hjson==3.1.0
httpcore==0.16.3
httpx==0.23.3
huggingface-hub==0.12.0
humanfriendly==10.0
hupper==1.11
hypothesis==6.65.2
idna==3.4
importlib-metadata==6.0.0
inflate64==0.3.1
iniconfig==2.0.0
ipadic==1.0.0
ipython==8.9.0
isodate==0.6.1
isort==5.12.0
itsdangerous==2.1.2
jaraco.classes==3.2.3
jax==0.3.6
jaxlib==0.3.5
jedi==0.18.2
jeepney==0.8.0
Jinja2==3.1.2
jinja2-time==0.2.0
jmespath==1.0.1
joblib==1.2.0
jsonschema==4.17.3
jupyter_core==5.1.5
kenlm==0.1
keras==2.11.0
keras-nlp==0.4.0
keyring==23.13.1
kiwisolver==1.4.4
kubernetes==12.0.1
language-tags==1.2.0
libclang==15.0.6.1
librosa==0.9.2
linkify-it-py==2.0.0
lit==15.0.7
llvmlite==0.39.1
loralib==0.1.1
lxml==4.9.2
Mako==1.2.4
Markdown==3.4.1
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.6.3
matplotlib-inline==0.1.6
mccabe==0.7.0
mdit-py-plugins==0.3.3
mdurl==0.1.2
more-itertools==9.1.0
mpmath==1.2.1
msgpack==1.0.4
multidict==6.0.4
multiprocess==0.70.12.2
multivolumefile==0.2.3
mypy-extensions==0.4.3
nbformat==5.7.3
networkx==3.0
ninja==1.11.1
nltk==3.8.1
numba==0.56.4
numpy==1.23.5
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
oauthlib==3.2.2
onnx==1.12.0
onnxconverter-common==1.13.0
opt-einsum==3.3.0
optax==0.1.4
-e git+https://github.com/younesbelkada/optimum.git@6bef6baa8867b34161296d6112297efd8e9f9f78#egg=optimum
optuna==3.1.0
orjson==3.8.6
packaging==23.0
pandas==1.5.3
parameterized==0.8.1
parso==0.8.3
PasteDeploy==3.0.1
pathspec==0.11.0
pathtools==0.1.2
pbkdf2==1.3
peft==0.2.0
pexpect==4.8.0
phonemizer==3.2.1
pickleshare==0.7.5
Pillow==9.4.0
Pint==0.16.1
pkginfo==1.9.6
plac==1.3.5
plaster==1.1.2
plaster-pastedeploy==1.0.1
platformdirs==2.6.2
plotly==5.13.0
pluggy==1.0.0
pooch==1.6.0
portalocker==2.0.0
poyo==0.5.0
prompt-toolkit==3.0.36
protobuf==3.19.6
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
py7zr==0.20.4
pyarrow==11.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybcj==1.0.1
pycodestyle==2.10.0
pycparser==2.21
pycryptodome==3.17
pycryptodomex==3.17
pyctcdecode==0.5.0
pydantic==1.10.4
pydub==0.25.1
pyflakes==3.0.1
Pygments==2.14.0
pygtrie==2.5.0
pylatexenc==2.10
pynvml==11.4.1
pyOpenSSL==23.0.0
pyparsing==3.0.9
pypng==0.20220715.0
pyppmd==1.0.0
pyramid==2.0.1
pyramid-mailer==0.15.1
pyrsistent==0.19.3
PySocks==1.7.1
pytest==7.2.1
pytest-subtests==0.9.0
pytest-timeout==2.1.0
pytest-xdist==3.1.0
python-dateutil==2.8.2
python-multipart==0.0.5
python-slugify==8.0.0
python3-openid==3.2.0
pytz==2022.7.1
pytz-deprecation-shim==0.1.0.post0
PyYAML==5.4.1
pyzstd==0.15.4
ray==2.2.0
rdflib==6.2.0
readme-renderer==37.3
regex==2022.10.31
repoze.sendmail==4.4.1
requests==2.28.2
requests-oauthlib==1.3.1
requests-toolbelt==0.10.1
resampy==0.3.0
responses==0.18.0
rfc3986==1.5.0
rhoknp==1.1.2
rich==13.3.1
rjieba==0.1.11
rouge-score==0.1.2
rsa==4.9
ruff==0.0.254
s3transfer==0.6.0
sacrebleu==1.5.1
sacremoses==0.0.53
safetensors==0.3.0
scikit-learn==1.2.1
scipy==1.10.0
seaborn==0.12.2
SecretStorage==3.3.3
segments==2.2.1
sentencepiece==0.1.97
sentry-sdk==1.14.0
setproctitle==1.3.2
sigopt==8.6.3
six==1.16.0
smmap==5.0.0
sniffio==1.3.0
sortedcontainers==2.4.0
soundfile==0.11.0
soupsieve==2.3.2.post1
SQLAlchemy==2.0.0
stack-data==0.6.2
starlette==0.25.0
SudachiDict-core==20230110
SudachiPy==0.6.6
sympy==1.11.1
tabulate==0.9.0
tenacity==8.1.0
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardX==2.5.1
tensorflow==2.11.0
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.30.0
tensorflow-text==2.11.0
tensorstore==0.1.30
termcolor==2.2.0
text-unidecode==1.3
texttable==1.6.7
tf2onnx==1.13.0
threadpoolctl==3.1.0
timeout-decorator==0.5.0
timm==0.6.12
tokenizers==0.12.1
tomli==2.0.1
toolz==0.12.0
torch==2.0.0
torchaudio==2.0.1
torchlibrosa==0.0.9
torchvision==0.15.1
tqdm==4.64.1
traitlets==5.8.1
transaction==3.0.1
-e git+https://github.com/younesbelkada/transformers.git@c965e2477163c6f5923ad041687f1ff9c5a323f8#egg=transformers
translationstring==1.4
triton==2.0.0
-e git+https://github.com/lvwerra/trl.git@24627e9c89fde0fe7c5652ab53cd7523c4c49d58#egg=trl
twine==4.0.2
typing_extensions==4.4.0
tzdata==2022.7
tzlocal==4.2
uc-micro-py==1.0.1
unidic==1.1.0
unidic-lite==1.0.8
uritemplate==4.1.1
urllib3==1.26.14
uvicorn==0.20.0
velruse==1.1.1
venusian==3.0.0
virtualenv==20.17.1
wandb==0.13.9
wasabi==0.10.1
wcwidth==0.2.6
webdataset==0.2.31
webencodings==0.5.1
WebOb==1.8.7
websocket-client==1.5.0
websockets==10.4
Werkzeug==2.2.2
wget==3.2
wrapt==1.14.1
WTForms==3.0.1
wtforms-recaptcha==0.3.2
xxhash==3.2.0
yarl==1.8.2
zipp==3.12.0
zope.deprecation==4.4.0
zope.interface==5.5.2
zope.sqlalchemy==1.6

Can you maybe upgrade accelerate? Also, I am using peft==0.2.0.

agademic commented 1 year ago

Unfortunately, upgrading accelerate to 0.17.1 and downgrading peft to 0.2.0 does not help. Still stuck at the same point.

lvwerra commented 1 year ago

What version of trl are you using?

agademic commented 1 year ago

trl==0.4.2.dev0

agademic commented 1 year ago

To make sure, I just created a new instance on GCP with a T4. Still stuck.

EDIT: I think the issue may be GCP-related. Closing for now; I will update when I find a solution.