ai-forever / ru-gpts

Russian GPT3 models.
Apache License 2.0

The model requires `num_beams`, although it is not needed in the example #105

Open LEv145 opened 1 year ago

LEv145 commented 1 year ago

Environment:

- Ubuntu 20.04
- pytorch==1.11.0a0+17540c5c
- NVIDIA CUDA 11.6.0
- TensorRT 8.2.3
- transformers==4.26.1
- apex https://github.com/NVIDIA/apex/commit/0c8400aa04f4279b1384ae0633e73d6daf9fead7 (or https://github.com/qywu/apex/commit/798a36cea73960a22ef3615ed3f28cd9dbc74931 with a patch to `_amp_state.py`)
- deepspeed==0.8.0
- triton==1.0.0
- timm==0.3.2

Code:

```py
import os

os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "5000"
os.environ["USE_DEEPSPEED"] = "1"

from src.xl_wrapper import RuGPT3XL

gpt = RuGPT3XL.from_pretrained(
    "sberbank-ai/rugpt3xl",
    weights_path="/mnt/store/models/rugpt3xl/mp_rank_00_model_states.pt",
    seq_len=512,
)
gpt.generate(
    (
        # "Кто был президентом США в 2020?" ("Who was the president of the USA in 2020?")
        "\u041a\u0442\u043e \u0431\u044b\u043b \u043f\u0440\u0435\u0437\u0438"
        "\u0434\u0435\u043d\u0442\u043e\u043c \u0421\u0428\u0410 \u0432 2020?"
    ),
    max_length=50,
    no_repeat_ngram_size=3,
    repetition_penalty=2.0,
)
```

Error:

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/opt/rugpts/src/xl_wrapper.py", line 224, in generate
    res = super().generate(
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/generation/utils.py", line 1331, in generate
    (generation_config.num_beams > 1)
TypeError: '>' not supported between instances of 'NoneType' and 'int'
```
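
The failing line compares `generation_config.num_beams` with `1`, so `num_beams` evidently reached transformers 4.26.1 as `None`. A plausible explanation, not confirmed in this thread, is that the wrapper forwards every decoding argument it did not receive as `None`, and that the newer `GenerationConfig` copies those `None` values over its defaults. The self-contained sketch below imitates that behaviour with a hypothetical `ToyGenerationConfig` class and a hypothetical `drop_none` helper (neither is part of transformers or ru-gpts) and shows one workaround: discarding `None` arguments before they reach the config.

```py
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class ToyGenerationConfig:
    """Hypothetical stand-in for a generation config with defaults."""
    num_beams: int = 1
    max_length: int = 20
    do_sample: bool = False

    def update(self, **kwargs: Any) -> None:
        # Mimics the suspected behaviour: every known key is overwritten,
        # even when the caller passed None.
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)


def drop_none(**kwargs: Any) -> Dict[str, Any]:
    """Keep only the arguments the caller actually set (hypothetical helper)."""
    return {k: v for k, v in kwargs.items() if v is not None}


def uses_beam_search(config: ToyGenerationConfig) -> bool:
    # Same comparison as in the traceback; raises TypeError if num_beams is None.
    return config.num_beams > 1


# A wrapper that forwards unset arguments as None breaks the comparison.
broken = ToyGenerationConfig()
broken.update(max_length=50, num_beams=None)
try:
    uses_beam_search(broken)
except TypeError as exc:
    print("TypeError:", exc)

# Filtering out None keeps the default num_beams=1 (greedy decoding).
fixed = ToyGenerationConfig()
fixed.update(**drop_none(max_length=50, num_beams=None))
print(fixed.num_beams, uses_beam_search(fixed))
```

If this is what happens, filtering the keyword arguments this way inside the wrapper's `generate`, before it calls `super().generate(...)`, would let the library defaults apply again.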

I don't know what `num_beams` does or how to make it work, but I would be happy to help.

Pip freeze:

```
absl-py==1.0.0
alabaster==0.7.12
apex==0.1
appdirs==1.4.4
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1618968359944/work
astunparse==1.6.3
attrs==21.4.0
audioread==2.1.9
Babel==2.9.1
backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work
beautifulsoup4 @ file:///home/conda/feedstock_root/build_artifacts/beautifulsoup4_1631087867185/work
black @ file:///home/conda/feedstock_root/build_artifacts/black-recipe_1643636307408/work
bleach==4.1.0
blis @ file:///home/conda/feedstock_root/build_artifacts/cython-blis_1636053204017/work
boto3==1.11.11
botocore==1.14.17
brotlipy @ file:///home/conda/feedstock_root/build_artifacts/brotlipy_1636012188166/work
cachetools==5.0.0
catalogue @ file:///home/conda/feedstock_root/build_artifacts/catalogue_1638867392804/work
certifi==2021.10.8
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1636046063618/work
chardet @ file:///home/conda/feedstock_root/build_artifacts/chardet_1635814844635/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1638815705608/work
click @ file:///home/conda/feedstock_root/build_artifacts/click_1635822600067/work
cloudpickle==2.0.0
codecov==2.1.12
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1602866480661/work
conda==4.11.0
conda-build==3.21.8
conda-package-handling @ file:///home/conda/feedstock_root/build_artifacts/conda-package-handling_1636021700973/work
coverage==6.3.1
cryptography @ file:///home/conda/feedstock_root/build_artifacts/cryptography_1639699280509/work
cudf @ file:///rapids/cudf-21.12.0a0%2B293.g0930f712e6-cp38-cp38-linux_x86_64.whl
cugraph @ file:///rapids/cugraph-21.12.0a0%2B95.g4b8c1330-cp38-cp38-linux_x86_64.whl
cuml @ file:///rapids/cuml-21.12.0a0%2B116.g4ce5bd609-cp38-cp38-linux_x86_64.whl
cupy-cuda115 @ file:///rapids/cupy_cuda115-9.6.0-cp38-cp38-manylinux1_x86_64.whl
cycler==0.11.0
cymem @ file:///home/conda/feedstock_root/build_artifacts/cymem_1636053152744/work
Cython==0.29.27
dask @ file:///rapids/dask-2021.11.2-py3-none-any.whl
dask-cuda @ file:///rapids/dask_cuda-21.12.0-py3-none-any.whl
dask-cudf @ file:///rapids/dask_cudf-21.12.0a0%2B293.g0930f712e6-py3-none-any.whl
dataclasses @ file:///home/conda/feedstock_root/build_artifacts/dataclasses_1628958434797/work
debugpy==1.5.1
decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
deepspeed==0.8.0
defusedxml==0.7.1
distributed @ file:///rapids/distributed-2021.11.2-py3-none-any.whl
docutils==0.15.2
entrypoints==0.3
executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1633213722787/work
expecttest==0.1.3
fastrlock==0.8
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1641470428964/work
flake8==3.7.9
Flask==2.0.3
flatbuffers==23.1.21
fonttools==4.29.1
fsspec==2022.1.0
future==0.18.2
gast==0.4.0
glob2==0.7
google-auth==2.6.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
graphsurgeon @ file:///workspace/TensorRT-8.2.3.0/graphsurgeon/graphsurgeon-0.4.5-py2.py3-none-any.whl
grpcio==1.43.0
h5py==3.8.0
HeapDict==1.0.1
hjson==3.1.0
huggingface-hub==0.12.0
hypothesis==4.50.8
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1609836280497/work
imagesize==1.3.0
importlib-metadata==4.11.1
importlib-resources==5.4.0
iniconfig==1.1.1
ipykernel==6.9.0
ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1642613634924/work
ipython-genutils==0.2.0
itsdangerous==2.0.1
jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1637175084646/work
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1636510082894/work
jmespath==0.10.0
joblib==1.1.0
json5==0.9.6
jsonschema==4.4.0
jupyter-client==7.1.2
jupyter-core==4.9.1
jupyter-tensorboard @ git+https://github.com/cliffwoolley/jupyter_tensorboard.git@ffa7e26138b82549453306e06b535a9ac36db17a
jupyterlab==2.3.2
jupyterlab-pygments==0.1.2
jupyterlab-server==1.2.0
jupytext==1.13.7
keras==2.11.0
kiwisolver==1.3.2
langcodes @ file:///home/conda/feedstock_root/build_artifacts/langcodes_1636741340529/work
libarchive-c @ file:///home/conda/feedstock_root/build_artifacts/python-libarchive-c_1643045750800/work
libclang==15.0.6.1
librosa==0.9.0
llvmlite==0.36.0
lmdb==1.3.0
locket==0.2.1
Markdown==3.3.6
markdown-it-py==1.1.0
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1635833572614/work
matplotlib==3.5.1
matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1631080358261/work
mccabe==0.6.1
mdit-py-plugins==0.3.0
mistune==0.8.4
mock @ file:///home/conda/feedstock_root/build_artifacts/mock_1635819534735/work
msgpack==1.0.3
murmurhash @ file:///home/conda/feedstock_root/build_artifacts/murmurhash_1636019583024/work
mypy-extensions @ file:///home/conda/feedstock_root/build_artifacts/mypy_extensions_1635839660470/work
nbclient==0.5.11
nbconvert==6.4.2
nbformat==5.1.3
nest-asyncio==1.5.4
networkx==2.6.3
ninja==1.11.1
nltk==3.7
notebook==6.4.1
numba @ file:///home/conda/feedstock_root/build_artifacts/numba_1623568544775/work
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1643958805350/work
nvidia-dali-cuda110==1.10.0
nvidia-pyindex==1.0.9
nvtx==0.2.4
oauthlib==3.2.0
onnx @ file:///opt/pytorch/pytorch/third_party/onnx
opt-einsum==3.3.0
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1637239678211/work
pandas==1.3.5
pandocfilters==1.5.0
parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
partd==1.2.0
pathspec @ file:///home/conda/feedstock_root/build_artifacts/pathspec_1626613672358/work
pathy @ file:///home/conda/feedstock_root/build_artifacts/pathy_1635227809952/work
pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1602535608087/work
pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
Pillow @ file:///tmp/pillow-simd
pkginfo @ file:///home/conda/feedstock_root/build_artifacts/pkginfo_1638813452194/work
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1644222440849/work
pluggy==1.0.0
polygraphy==0.33.0
pooch==1.6.0
preshed @ file:///home/conda/feedstock_root/build_artifacts/preshed_1636077712344/work
prettytable==3.1.0
prometheus-client==0.13.1
prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1643362612956/work
protobuf==3.19.4
psutil @ file:///home/conda/feedstock_root/build_artifacts/psutil_1640887117172/work
ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
py==1.11.0
py-cpuinfo==9.0.0
pyarrow @ file:///rapids/pyarrow-5.0.0-cp38-cp38-linux_x86_64.whl
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybind11==2.9.1
pycocotools @ git+https://github.com/nvidia/cocoapi.git@142b17a358fdb5a31f9d5153d7a9f3f1cd385178#subdirectory=PythonAPI
pycodestyle==2.5.0
pycosat @ file:///home/conda/feedstock_root/build_artifacts/pycosat_1636020377748/work
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic @ file:///home/conda/feedstock_root/build_artifacts/pydantic_1636021149719/work
pydot==1.4.2
pyflakes==2.1.1
Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1641580240686/work
pynvml==11.4.1
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1633192417276/work
pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1642753572664/work
pyrsistent==0.18.1
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1635862404924/work
pytest==6.2.5
pytest-cov==3.0.0
pytest-pythonpath==0.7.4
python-dateutil==2.8.2
python-hostlist==1.21
python-nvd3==0.15.0
python-slugify==5.0.2
pytorch-quantization==2.1.2
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1633452062248/work
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1636139793187/work
pyzmq==22.3.0
regex==2020.1.8
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1637771257551/work
requests-oauthlib==1.3.1
resampy==0.2.2
revtok @ git+git://github.com/jekbradbury/revtok.git@f1998b72a941d1e5f9578a66dc1c20b01913caab
rmm @ file:///rapids/rmm-21.12.0a0%2B31.g0acbd51-cp38-cp38-linux_x86_64.whl
rsa==4.8
ruamel-yaml-conda @ file:///home/conda/feedstock_root/build_artifacts/ruamel_yaml_1636009157217/work
s3transfer==0.3.7
sacremoses==0.0.47
scikit-learn @ file:///rapids/scikit_learn-0.24.0-cp38-cp38-manylinux2010_x86_64.whl
scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy_1619561901336/work
Send2Trash==1.8.0
sentencepiece==0.1.97
shellingham @ file:///home/conda/feedstock_root/build_artifacts/shellingham_1612179560728/work
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
smart-open @ file:///home/conda/feedstock_root/build_artifacts/smart_open_1630238320325/work
snowballstemmer==2.2.0
sortedcontainers==2.4.0
SoundFile==0.10.3.post1
soupsieve @ file:///home/conda/feedstock_root/build_artifacts/soupsieve_1638550740809/work
spacy @ file:///home/conda/feedstock_root/build_artifacts/spacy_1642167419405/work
spacy-legacy @ file:///home/conda/feedstock_root/build_artifacts/spacy-legacy_1625687473390/work
spacy-loggers @ file:///home/conda/feedstock_root/build_artifacts/spacy-loggers_1634809367310/work
Sphinx==4.4.0
sphinx-glpi-theme==0.3
sphinx-rtd-theme==1.0.0
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
srsly @ file:///home/conda/feedstock_root/build_artifacts/srsly_1638879568141/work
stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1642255706390/work
tabulate==0.8.9
tblib==1.7.0
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-estimator==2.11.0
tensorflow-io-gcs-filesystem==0.30.0
tensorrt @ file:///workspace/TensorRT-8.2.3.0/python/tensorrt-8.2.3.0-cp38-none-linux_x86_64.whl
termcolor==2.2.0
terminado==0.13.1
testpath==0.5.0
text-unidecode==1.3
thinc @ file:///home/conda/feedstock_root/build_artifacts/thinc_1638980259098/work
threadpoolctl==3.1.0
timm==0.3.2
tokenizers==0.13.2
toml==0.10.2
tomli @ file:///home/conda/feedstock_root/build_artifacts/tomli_1644342247877/work
toolz==0.11.2
torch==1.11.0a0+17540c5
torch-tensorrt @ file:///opt/pytorch/torch_tensorrt/py/dist/torch_tensorrt-1.1.0a0-cp38-cp38-linux_x86_64.whl
torchtext @ file:///opt/pytorch/text
torchvision @ file:///opt/pytorch/vision
tornado==6.1
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1632160078689/work
traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1635260543454/work
transformers==4.26.1
treelite @ file:///rapids/treelite-2.1.0-py3-none-manylinux2014_x86_64.whl
treelite-runtime @ file:///rapids/treelite_runtime-2.1.0-py3-none-manylinux2014_x86_64.whl
triton==1.0.0
typed-ast @ file:///home/conda/feedstock_root/build_artifacts/typed-ast_1643045767561/work
typer @ file:///home/conda/feedstock_root/build_artifacts/typer_1630326630489/work
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1638334978229/work
ucx-py @ file:///rapids/ucx_py-0.21.0a0%2B37.gbfa0450-cp38-cp38-linux_x86_64.whl
uff @ file:///workspace/TensorRT-8.2.3.0/uff/uff-0.6.9-py2.py3-none-any.whl
urllib3==1.25.11
wasabi @ file:///home/conda/feedstock_root/build_artifacts/wasabi_1638865582891/work
wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1600965781394/work
webencodings==0.5.1
Werkzeug==2.0.3
wget==3.2
wrapt==1.14.1
xgboost @ file:///rapids/xgboost-1.5.0-cp38-cp38-linux_x86_64.whl
zict==2.0.0
zipp==3.7.0
```

TatianaShavrina commented 1 year ago

Hey @LEv145, thank you for bringing that up!

The `num_beams` parameter controls the beam search decoding strategy of the model (see the HuggingFace explanation). Try passing it to the generation function as an argument, or stick to sampling or greedy generation.
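
To make the difference between the strategies concrete, here is a toy sketch comparing greedy decoding with beam search over a hypothetical next-token table (`TOY_MODEL`, invented for illustration and unrelated to ruGPT3XL or to the transformers implementation). Greedy decoding commits to the single most likely token at each step, while beam search keeps the `num_beams` best partial sequences by cumulative log-probability and can therefore recover a sequence whose first token was not the locally best one.

```py
import math
from typing import Dict, List, Tuple

Seq = Tuple[str, ...]

# Hypothetical next-token probabilities, keyed by the tokens generated so far.
TOY_MODEL: Dict[Seq, Dict[str, float]] = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.3, "y": 0.3, "z": 0.4},
    ("B",): {"x": 0.9, "y": 0.1},
}


def greedy(steps: int) -> Tuple[Seq, float]:
    """Pick the most probable next token at every step."""
    seq: Seq = ()
    logp = 0.0
    for _ in range(steps):
        token, p = max(TOY_MODEL[seq].items(), key=lambda kv: kv[1])
        seq += (token,)
        logp += math.log(p)
    return seq, logp


def beam_search(steps: int, num_beams: int) -> Tuple[Seq, float]:
    """Keep the num_beams best partial sequences at every step."""
    beams: List[Tuple[Seq, float]] = [((), 0.0)]
    for _ in range(steps):
        # Expand every surviving beam by every possible next token.
        candidates = [
            (seq + (tok,), logp + math.log(p))
            for seq, logp in beams
            for tok, p in TOY_MODEL[seq].items()
        ]
        # Prune back to the num_beams highest cumulative log-probabilities.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
    return beams[0]


g_seq, g_logp = greedy(2)
b_seq, b_logp = beam_search(2, num_beams=2)
print("greedy:", g_seq, round(math.exp(g_logp), 2))
print("beam:  ", b_seq, round(math.exp(b_logp), 2))
```

With this table, greedy decoding ends at `('A', 'z')` with probability 0.24, while beam search with two beams finds `('B', 'x')` with probability 0.36; with `num_beams=1` the beam search reduces to greedy decoding.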

The parameters can be found in the `generate` function in the `xl_wrapper` script.
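
Along the lines of this advice, one option is to choose a strategy explicitly and pass every relevant argument, so that none of them is left unset. The helper `decoding_kwargs` below is hypothetical and its values are only illustrative; the keys (`num_beams`, `do_sample`, `top_k`, `top_p`, `early_stopping`) are standard Hugging Face `generate` arguments, and the sketch assumes the ru-gpts wrapper forwards them unchanged.

```py
from typing import Any, Dict


def decoding_kwargs(strategy: str) -> Dict[str, Any]:
    """Return explicit decoding arguments for one strategy (illustrative values)."""
    if strategy == "greedy":
        return {"num_beams": 1, "do_sample": False}
    if strategy == "sampling":
        return {"num_beams": 1, "do_sample": True, "top_k": 50, "top_p": 0.95}
    if strategy == "beam":
        return {"num_beams": 5, "do_sample": False, "early_stopping": True}
    raise ValueError(f"unknown strategy: {strategy!r}")


# Usage, assuming `gpt` was loaded with RuGPT3XL.from_pretrained as in the issue:
# result = gpt.generate("Кто был президентом США в 2020?", max_length=50,
#                       **decoding_kwargs("beam"))
```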

LEv145 commented 1 year ago

Thanks, it works! But now there is a problem when processing the result:

```
Load checkpoint from /mnt/store/models/rugpt3xl/mp_rank_00_model_states.pt
Model Loaded
Traceback (most recent call last):
  File "/mnt/store/tests/test_rugpt3xl.py", line 29, in <module>
    main()
  File "/mnt/store/tests/test_rugpt3xl.py", line 19, in main
    result = gpt.generate(
  File "/opt/ru-gpts/src/xl_wrapper.py", line 244, in generate
    return list(map(self.tokenizer.decode, res.tolist()))
AttributeError: 'NoneType' object has no attribute 'tolist'
```

Code:

```py
import os
import sys

sys.path.append("/opt/ru-gpts/")

os.environ["USE_DEEPSPEED"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "5000"

from src.xl_wrapper import RuGPT3XL


def main():
    gpt = RuGPT3XL.from_pretrained(
        "sberbank-ai/rugpt3xl",
        weights_path="/mnt/store/models/rugpt3xl/mp_rank_00_model_states.pt",
        seq_len=512,
    )
    result = gpt.generate(
        "Кто был президентом США в 2020? ",  # "Who was the president of the USA in 2020?"
        max_length=50,
        num_beams=5,
        early_stopping=True,
    )
    print(result)


if __name__ == "__main__":
    main()
```

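
A possible explanation for the `None` result, again unverified here, is the same `None` forwarding: if the wrapper passes `do_sample=None`, and transformers 4.26.1 selects its decoding routine with identity checks such as `do_sample is False` and `do_sample is True`, then a call with `num_beams=5` matches no routine and `generate` returns `None` instead of a tensor. The function `select_mode` below is a hypothetical, simplified model of such a dispatch, not the library code.

```py
from typing import Optional


def select_mode(num_beams: int, do_sample: Optional[bool]) -> Optional[str]:
    """Hypothetical simplification of a decoding-mode dispatch.

    Identity checks (`is True` / `is False`) treat None as neither value,
    so an unset do_sample falls through every branch.
    """
    if num_beams == 1 and do_sample is False:
        return "greedy"
    if num_beams == 1 and do_sample is True:
        return "sample"
    if num_beams > 1 and do_sample is False:
        return "beam_search"
    if num_beams > 1 and do_sample is True:
        return "beam_sample"
    return None  # no branch matched, so the caller receives None


# num_beams set, do_sample left unset (None): no decoding routine is chosen.
print(select_mode(5, None))
# Passing do_sample explicitly selects beam search.
print(select_mode(5, False))
```

If this reading is correct, adding `do_sample=False` to the call in the script above (or patching the wrapper to drop `None` arguments) should make `generate` return token IDs again.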
sh0tcall3r commented 1 year ago

I have the same problem when generating text with the model. First it requires `num_beams`, and once that is set, `AttributeError: 'NoneType' object has no attribute 'tolist'` appears, as in the post above. Please fix this or explain how to resolve it.