seujung / KoBART-summarization

Summarization module based on KoBART
MIT License

Error during KoBART summarization fine-tuning #23

Closed. BrainNim closed this issue 2 years ago.

BrainNim commented 2 years ago

Hello, and thank you for releasing such a good model.

I installed KoBART summarization and, to fine-tune it, ran the following command from the README.

[use cpu]
python train.py  --gradient_clip_val 1.0 --max_epochs 50 --default_root_dir logs  --batch_size 4 --num_workers 4

However, the following error occurred during the validation sanity check.

INFO:root:Namespace(accelerator=None, accumulate_grad_batches=1, amp_backend='native', amp_level='O2', auto_lr_find=False, auto_scale_batch_size=False, auto_select_gpus=False, batch_size=4, benchmark=False, check_val_every_n_epoch=1, checkpoint_callback=True, checkpoint_path=None, default_root_dir='logs', deterministic=False, distributed_backend=None, fast_dev_run=False, flush_logs_every_n_steps=100, gpus=None, gradient_clip_algorithm='norm', gradient_clip_val=1.0, limit_predict_batches=1.0, limit_test_batches=1.0, limit_train_batches=1.0, limit_val_batches=1.0, log_every_n_steps=50, log_gpu_memory=None, logger=True, lr=3e-05, max_epochs=50, max_len=512, max_steps=None, max_time=None, min_epochs=None, min_steps=None, model_path=None, move_metrics_to_cpu=False, multiple_trainloader_mode='max_size_cycle', num_nodes=1, num_processes=1, num_sanity_val_steps=2, num_workers=4, overfit_batches=0.0, plugins=None, precision=32, prepare_data_per_node=True, process_position=0, profiler=None, progress_bar_refresh_rate=None, reload_dataloaders_every_epoch=False, replace_sampler_ddp=True, resume_from_checkpoint=None, stochastic_weight_avg=False, sync_batchnorm=False, terminate_on_nan=False, test_file='data/test.tsv', tpu_cores=None, track_grad_norm=-1, train_file='data/train.tsv', truncated_bptt_steps=None, val_check_interval=1.0, warmup_ratio=0.1, weights_save_path=None, weights_summary='top')
using cached model
using cached model
using cached model
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
INFO:root:number of workers 4, data length 34242
INFO:root:num_train_steps : 107006
INFO:root:num_warmup_steps : 10700
2021-11-05 10:27:55.060417: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'cudart64_101.dll'; dlerror: cudart64_101.dll not found
2021-11-05 10:27:55.069132: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

  | Name  | Type                         | Params
-------------------------------------------------------
0 | model | BartForConditionalGeneration | 123 M
-------------------------------------------------------
123 M     Trainable params
0         Non-trainable params
123 M     Total params
495.440   Total estimated model params size (MB)
Validation sanity check:   0%|                                                                   | 0/2 [00:00<?, ?it/s]Traceback (most recent call last):
  File "train.py", line 233, in <module>
    trainer.fit(model, dm)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\trainer\trainer.py", line 460, in fit
    self._run(model)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\trainer\trainer.py", line 758, in _run
    self.dispatch()
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\trainer\trainer.py", line 799, in dispatch
    self.accelerator.start_training(self)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\accelerators\accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\trainer\trainer.py", line 809, in run_stage
    return self.run_train()
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\trainer\trainer.py", line 844, in run_train
    self.run_sanity_check(self.lightning_module)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\trainer\trainer.py", line 1112, in run_sanity_check
    self.run_evaluation()
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\trainer\trainer.py", line 967, in run_evaluation
    output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 174, in evaluation_step
    output = self.trainer.accelerator.validation_step(args)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\accelerators\accelerator.py", line 226, in validation_step
    return self.training_type_plugin.validation_step(*args)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py", line 161, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "train.py", line 195, in validation_step
    outs = self(batch)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "train.py", line 185, in forward
    labels=inputs['labels'], return_dict=True)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\transformers\models\bart\modeling_bart.py", line 1295, in forward
    return_dict=return_dict,
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\transformers\models\bart\modeling_bart.py", line 1157, in forward
    return_dict=return_dict,
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\transformers\models\bart\modeling_bart.py", line 748, in forward
    inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\modules\sparse.py", line 126, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\functional.py", line 1852, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.IntTensor instead (while checking arguments for embedding)

What should I do to get it working? Thank you.
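As an aside, the schedule numbers in the log (`num_train_steps : 107006`, `num_warmup_steps : 10700`) can be reproduced from the Namespace values. The sketch below assumes the schedule is `data_len / (batch_size * num_workers) * max_epochs` with warmup `= num_train_steps * warmup_ratio`; this formula is inferred from the logged numbers, not copied from train.py. If it is right, `--num_workers` (a DataLoader setting) also changes the learning-rate schedule.

```python
def schedule_steps(data_len, batch_size, num_workers, max_epochs, warmup_ratio):
    # Inferred formula (an assumption): it reproduces the logged values,
    # but it is not taken verbatim from train.py.
    num_train_steps = int(data_len / (batch_size * num_workers) * max_epochs)
    num_warmup_steps = int(num_train_steps * warmup_ratio)
    return num_train_steps, num_warmup_steps


# data length 34242, batch_size=4, num_workers=4, max_epochs=50, warmup_ratio=0.1
print(schedule_steps(34242, 4, 4, 50, 0.1))  # (107006, 10700)
```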

bellhyeon commented 2 years ago

You can fix this by changing inputs['labels'] to inputs['labels'].long() in forward.

BrainNim commented 2 years ago

Hello, thank you for the reply. I changed that part of the forward function in train.py, but every variant still raised an error.

      return self.model(input_ids=inputs['input_ids'],
                        attention_mask=attention_mask,
                        decoder_input_ids=inputs['decoder_input_ids'],
                        decoder_attention_mask=decoder_attention_mask,
                        labels=inputs['labels'], return_dict=True)

For the labels=inputs['labels'] argument, I tried labels=inputs['labels'].long(), labels=inputs['labels'].int(), labels=long(inputs['labels']), and labels=int(inputs['labels']), but each one produced an error similar to the original.
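For reference, the four variants behave quite differently on an integer tensor. The sketch uses a hypothetical int32 stand-in for `inputs['labels']`: `.long()` returns an int64 copy, `.int()` keeps int32, Python 3 has no built-in `long`, and `int()` only converts one-element tensors. The exact exception type raised by `int()` may differ between PyTorch versions, so the sketch only reports it.

```python
import torch

# Hypothetical stand-in for inputs['labels'] as an IntTensor (int32).
labels = torch.tensor([[5, 6, 7]], dtype=torch.int32)


def error_name(fn, tensor):
    """Return the exception class name raised by fn(tensor), or None on success."""
    try:
        fn(tensor)
    except Exception as err:  # broad on purpose: we only report what was raised
        return type(err).__name__
    return None


as_long = labels.long()  # int64 copy with the same shape
as_int = labels.int()    # stays int32, so it cannot fix a Long-vs-Int mismatch

# Python 3 has no built-in `long`, so calling this lambda raises NameError.
long_error = error_name(lambda t: long(t), labels)  # noqa: F821
# int() converts only one-element tensors; this 3-element tensor is rejected.
int_error = error_name(int, labels)

print(as_long.dtype, as_int.dtype, long_error, int_error)
```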

Validation sanity check:   0%|                                                                   | 0/2 [00:00<?, ?it/s]Traceback (most recent call last):
  File "train.py", line 233, in <module>
    trainer.fit(model, dm)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\pytorch_lightning\trainer\trainer.py", line 460, in fit
    self._run(model)

(omitted)

  File "train.py", line 195, in validation_step
    outs = self(batch)
  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "train.py", line 185, in forward
    labels=inputs['labels'].long(), return_dict=True)

(omitted)

  File "C:\Users\Newrun\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.7_qbz5n2kfra8p0\LocalCache\local-packages\Python37\site-packages\torch\nn\functional.py", line 1852, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.IntTensor instead (while checking arguments for embedding)
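Both tracebacks end inside the embedding lookup (`self.embed_tokens(input_ids)` in modeling_bart.py), so the IntTensor being rejected is a token-id input to BART, not `labels`. Casting `labels` alone therefore leaves the failing call unchanged. A minimal sketch of a workaround, assuming the batch is the `inputs` dict from the forward() snippet above: cast every id tensor to int64 before calling the model. `to_long_indices` and `INDEX_KEYS` are hypothetical names, not part of train.py.

```python
import torch

# Id tensors in the `inputs` dict (key names taken from the forward() snippet above).
INDEX_KEYS = ("input_ids", "decoder_input_ids", "labels")


def to_long_indices(inputs, keys=INDEX_KEYS):
    """Return a shallow copy of `inputs` with the listed id tensors cast to int64."""
    out = dict(inputs)
    for key in keys:
        if key in out:
            out[key] = out[key].long()
    return out


# Hypothetical usage inside the LightningModule's forward():
#     inputs = to_long_indices(inputs)
#     return self.model(input_ids=inputs['input_ids'], ..., labels=inputs['labels'], return_dict=True)

# Stand-alone check with int32 ids, the dtype reported in the error message.
batch = {
    "input_ids": torch.tensor([[0, 15, 27, 2]], dtype=torch.int32),
    "decoder_input_ids": torch.tensor([[0, 7, 9]], dtype=torch.int32),
    "labels": torch.tensor([[7, 9, 1]], dtype=torch.int32),
}
fixed = to_long_indices(batch)
embed = torch.nn.Embedding(num_embeddings=32, embedding_dim=8)
vectors = embed(fixed["input_ids"])  # int64 indices are what embedding expects
print({k: v.dtype for k, v in fixed.items()}, tuple(vectors.shape))
```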

How can I fix this?

seujung commented 2 years ago

@BrainNim This looks like an issue with the runtime environment. Could you describe the environment you ran it in?

BrainNim commented 2 years ago

I'm trying to run it on the CPU under Windows 10. From the KoBart-summarization directory, with KoBART installed via pip install git+https://github.com/SKT-AI/KoBART#egg=kobart, I ran python train.py --gradient_clip_val 1.0 --max_epochs 50 --default_root_dir logs --batch_size 4 --num_workers 4 in cmd.

What information should I include when describing the environment?

Python 3.7.9

Package                          Version
-------------------------------- -------------------
absl-py                          0.13.0
aiohttp                          3.8.0
aiosignal                        1.2.0
altair                           4.1.0
astor                            0.8.1
astroid                          2.4.1
astunparse                       1.6.3
async-timeout                    4.0.0
asynctest                        0.13.0
attrs                            19.3.0
backcall                         0.1.0
backports.zoneinfo               0.2.1
base58                           2.1.1
beautifulsoup4                   4.6.0
bleach                           3.1.5
blinker                          1.4
bs4                              0.0.1
cached-property                  1.5.2
cachetools                       4.2.2
certifi                          2019.6.16
chardet                          3.0.4
charset-normalizer               2.0.7
clang                            5.0
click                            7.1.2
colorama                         0.4.3
cycler                           0.11.0
Cython                           0.29.14
DateTime                         4.3
decorator                        4.4.2
defusedxml                       0.6.0
entrypoints                      0.3
et-xmlfile                       1.0.1
eunjeon                          0.4.0
filelock                         3.3.2
Flask                            1.1.2
flatbuffers                      1.12
frozenlist                       1.2.0
fsspec                           2021.10.1
future                           0.18.2
gast                             0.3.3
gdown                            4.2.0
gensim                           3.8.3
gitdb                            4.0.9
GitPython                        3.1.24
google-auth                      1.33.0
google-auth-oauthlib             0.4.4
google-pasta                     0.2.0
googlemaps                       3.1.1
grpcio                           1.34.0
h5py                             2.10.0
haversine                        2.3.0
huggingface-hub                  0.0.12
idna                             2.8
importlib-metadata               1.6.0
ipykernel                        5.2.1
ipython                          7.14.0
ipython-genutils                 0.2.0
ipywidgets                       7.5.1
isort                            4.3.21
itsdangerous                     1.1.0
jdcal                            1.4.1
jedi                             0.17.0
Jinja2                           2.11.2
joblib                           0.15.1
JPype1                           1.1.2
json5                            0.9.5
jsonschema                       3.2.0
jupyter                          1.0.0
jupyter-client                   6.1.3
jupyter-console                  6.1.0
jupyter-core                     4.6.3
jupyterlab                       2.2.8
jupyterlab-server                1.2.0
keras                            2.6.0
keras-bert                       0.88.0
keras-embed-sim                  0.9.0
keras-layer-normalization        0.15.0
keras-multi-head                 0.28.0
keras-nightly                    2.5.0.dev2021032900
keras-pos-embd                   0.12.0
keras-position-wise-feed-forward 0.7.0
Keras-Preprocessing              1.1.2
keras-radam                      0.15.0
keras-self-attention             0.50.0
keras-transformer                0.39.0
kiwisolver                       1.3.2
kobart                           0.4
konlpy                           0.5.2
lazy-object-proxy                1.4.3
lightgbm                         3.2.1
lxml                             4.6.3
Markdown                         3.3.4
MarkupSafe                       1.1.1
matplotlib                       3.4.3
mccabe                           0.6.1
mecab-python                     0.996-ko-0.9.2-msvc
mistune                          0.8.4
multidict                        5.2.0
nbconvert                        5.6.1
nbformat                         5.0.6
nltk                             3.5
notebook                         6.0.3
numpy                            1.19.5
oauthlib                         3.1.0
opencv-contrib-python            4.5.4.58
opencv-python                    4.5.4.58
openpyxl                         3.0.5
opt-einsum                       3.3.0
packaging                        21.2
pandas                           1.0.3
pandocfilters                    1.4.2
parso                            0.7.0
pickleshare                      0.7.5
Pillow                           8.4.0
pip                              21.3.1
prometheus-client                0.7.1
prompt-toolkit                   3.0.5
protobuf                         3.17.3
pyarrow                          6.0.0
pyasn1                           0.4.8
pyasn1-modules                   0.2.8
pydeck                           0.7.1
pyDeprecate                      0.3.0
Pygments                         2.6.1
pylint                           2.5.2
PyMySQL                          0.10.1
pyparsing                        2.4.7
pyproj                           3.0.0.post1
pyrsistent                       0.16.0
PySocks                          1.7.1
python-dateutil                  2.8.1
pytorch-lightning                1.3.8
pytz                             2020.1
pytz-deprecation-shim            0.1.0.post0
pywin32                          227
pywinpty                         0.5.7
PyYAML                           5.4.1
pyzmq                            19.0.1
qtconsole                        4.7.4
QtPy                             1.9.0
regex                            2021.3.17
requests                         2.22.0
requests-oauthlib                1.3.0
rsa                              4.7.2
sacremoses                       0.0.46
scikit-learn                     0.24.2
scipy                            1.4.1
selenium                         3.141.0
Send2Trash                       1.5.0
six                              1.15.0
sklearn                          0.0
smart-open                       4.0.1
smmap                            5.0.0
soupsieve                        1.9.3
streamlit                        0.72.0
tensorboard                      2.2.2
tensorboard-data-server          0.6.1
tensorboard-plugin-wit           1.8.0
tensorflow                       2.2.0
tensorflow-cpu                   2.5.0
tensorflow-estimator             2.2.0
termcolor                        1.1.0
terminado                        0.8.3
testpath                         0.4.4
threadpoolctl                    2.0.0
tokenizers                       0.10.3
toml                             0.10.1
toolz                            0.11.1
torch                            1.7.1
torchmetrics                     0.6.0
tornado                          6.0.4
tqdm                             4.51.0
traitlets                        4.3.3
transformers                     4.3.3
tweepy                           3.10.0
typed-ast                        1.4.1
typing-extensions                3.7.4.3
tzdata                           2021.5
tzlocal                          4.1
urllib3                          1.25.3
validators                       0.18.2
watchdog                         2.1.6
wcwidth                          0.1.9
webencodings                     0.5.1
Werkzeug                         1.0.1
wheel                            0.36.2
widgetsnbextension               3.5.1
wrapt                            1.12.1
xlrd                             1.2.0
XlsxWriter                       1.4.0
yarl                             1.7.2
zipp                             3.1.0
zope.interface                   5.2.0
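For an environment report, the details that turned out to matter here are the OS, the Python version, and the deep-learning packages. A minimal sketch that prints just those instead of a full pip list; the package selection is an assumption. It uses `importlib.metadata`, which needs Python 3.8+; on the reporter's Python 3.7, the importlib-metadata backport listed above provides the same `version()` function under the module name `importlib_metadata`.

```python
import platform
from importlib import metadata  # Python 3.8+; use importlib_metadata on 3.7

# Packages most relevant to this issue (an assumed selection).
PACKAGES = ("torch", "transformers", "pytorch-lightning", "kobart", "numpy")


def environment_report(packages=PACKAGES):
    """Collect OS, Python, and package versions into a dict."""
    report = {
        "os": platform.platform(),
        "python": platform.python_version(),
    }
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```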
seujung commented 2 years ago

This appears to be a PyTorch version issue. I have updated requirements.txt accordingly, so please refer to it. I confirmed that training runs correctly on the CPU on a Mac.
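Installing the PyTorch version from requirements.txt is the maintainer's fix. A complementary, version-independent sketch follows, resting on an assumption the thread does not confirm: that the IntTensor comes from NumPy arrays built without an explicit dtype. NumPy 1.x uses a 32-bit default integer on Windows (64-bit on Linux and macOS), and `torch.from_numpy` keeps the array's dtype, which could be why the error shows up on Windows but not on a Mac. `encode_ids` is a hypothetical helper, not the repository's dataset code, and the pad id is illustrative.

```python
import numpy as np
import torch


def encode_ids(token_ids, max_len, pad_id):
    """Hypothetical helper: truncate/pad token ids to max_len as an int64 array."""
    ids = list(token_ids)[:max_len]
    ids += [pad_id] * (max_len - len(ids))
    # Explicit int64 so the result does not depend on NumPy's platform default.
    return np.array(ids, dtype=np.int64)


arr = encode_ids([0, 15, 27, 2], max_len=6, pad_id=3)  # pad id 3 is illustrative
t = torch.from_numpy(arr)  # shares memory with arr and keeps its int64 dtype
print(t.dtype)  # torch.int64
```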