hankcs / HanLP

Natural Language Processing for the next decade. Tokenization, Part-of-Speech Tagging, Named Entity Recognition, Syntactic & Semantic Dependency Parsing, Document Classification
https://hanlp.hankcs.com/en/
Apache License 2.0
33.73k stars 10.08k forks source link

open_base.py训练有bug , mat1 and mat2 shapes cannot be multiplied #1703

Closed Yumeka999 closed 2 years ago

Yumeka999 commented 2 years ago

Describe the bug: 运行 open_base.py 训练时报错:

mat1 and mat2 shapes cannot be multiplied (800x256 and 768x1536)

Code to reproduce the issue

# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-12-03 14:24

import os
from hanlp.common.dataset import SortingSamplerBuilder
from hanlp.common.transform import NormalizeCharacter
from hanlp.components.mtl.multi_task_learning import MultiTaskLearning
from hanlp.components.mtl.tasks.constituency import CRFConstituencyParsing
from hanlp.components.mtl.tasks.dep import BiaffineDependencyParsing
from hanlp.components.mtl.tasks.ner.tag_ner import TaggingNamedEntityRecognition
from hanlp.components.mtl.tasks.pos import TransformerTagging
from hanlp.components.mtl.tasks.sdp import BiaffineSemanticDependencyParsing
from hanlp.components.mtl.tasks.srl.bio_srl import SpanBIOSemanticRoleLabeling
from hanlp.components.mtl.tasks.tok.tag_tok import TaggingTokenization
from hanlp.datasets.ner.msra import MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TRAIN, MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_DEV, \
    MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TEST
from hanlp.datasets.parsing.ctb8 import CTB8_POS_TRAIN, CTB8_POS_DEV, CTB8_POS_TEST, CTB8_SD330_TEST, CTB8_SD330_DEV, \
    CTB8_SD330_TRAIN, CTB8_CWS_TRAIN, CTB8_CWS_DEV, CTB8_CWS_TEST, CTB8_BRACKET_LINE_NOEC_TRAIN, \
    CTB8_BRACKET_LINE_NOEC_DEV, CTB8_BRACKET_LINE_NOEC_TEST
from hanlp.datasets.parsing.semeval16 import SEMEVAL2016_TEXT_TRAIN_CONLLU, SEMEVAL2016_TEXT_TEST_CONLLU, \
    SEMEVAL2016_TEXT_DEV_CONLLU
# from hanlp.datasets.srl.ontonotes5.chinese import ONTONOTES5_CONLL12_CHINESE_TEST, ONTONOTES5_CONLL12_CHINESE_DEV, \
#     ONTONOTES5_CONLL12_CHINESE_TRAIN
from hanlp.layers.embeddings.contextual_word_embedding import ContextualWordEmbedding
from hanlp.layers.transformers.relative_transformer import RelativeTransformerEncoder
from hanlp.utils.lang.zh.char_table import HANLP_CHAR_TABLE_JSON
from hanlp.utils.log_util import cprint

# Absolute path of the directory one level above the folder containing this script.
root = os.path.abspath(
    os.path.join(
        os.path.dirname(__file__),
        os.pardir,
    )
)


def cdroot():
    """Change the working directory to the project root.

    After this call, relative paths resolve against ``root``, so models
    written to a relative save directory end up in the root folder.
    """
    project_root = root
    os.chdir(project_root)

# Batch size shared by every task's SortingSamplerBuilder below.
n_batch_size = 8

# Task name -> task configuration handed to MultiTaskLearning.fit().
# Only 'tok' and 'ner' are active; the other tasks are commented out.
tasks = {
    'tok': TaggingTokenization(  # tokenization (word segmentation)
        CTB8_CWS_TRAIN,
        CTB8_CWS_DEV,
        CTB8_CWS_TEST,
        SortingSamplerBuilder(batch_size=n_batch_size),
        max_seq_len=510,
        hard_constraint=True,
        char_level=True,
        tagging_scheme='BMES',
        lr=1e-3,
        transform=NormalizeCharacter(HANLP_CHAR_TABLE_JSON, 'token'),
    ),
    # Part-of-speech tagging (disabled).
    # 'pos': TransformerTagging(  # 词性
    #     CTB8_POS_TRAIN,
    #     CTB8_POS_DEV,
    #     CTB8_POS_TEST,
    #     SortingSamplerBuilder(batch_size=n_batch_size),
    #     hard_constraint=True,
    #     max_seq_len=510,
    #     char_level=True,
    #     dependencies='tok',
    #     lr=1e-3,
    # ),
    'ner': TaggingNamedEntityRecognition(  # named entity recognition; this task raises the mat1/mat2 shape error
        MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TRAIN,
        MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_DEV,
        MSRA_NER_TOKEN_LEVEL_SHORT_IOBES_TEST,
        SortingSamplerBuilder(batch_size=n_batch_size),
        lr=1e-3,
        # NOTE(review): 768 is hard-coded here, while the encoder given to
        # mtl.fit() is the *small* ELECTRA model. The reported error
        # (800x256 and 768x1536) indicates the incoming features are 256-wide,
        # so this presumably needs to be 256 for the small model - confirm.
        secondary_encoder=RelativeTransformerEncoder(768, k_as_x=True),
        dependencies='tok',
    ),
    # Semantic role labeling (disabled: dataset download error).
    # 'srl': SpanBIOSemanticRoleLabeling(  # 依存句法  download error
    #     ONTONOTES5_CONLL12_CHINESE_TRAIN,
    #     ONTONOTES5_CONLL12_CHINESE_DEV,
    #     ONTONOTES5_CONLL12_CHINESE_TEST,
    #     SortingSamplerBuilder(batch_size=n_batch_size, batch_max_tokens=2048),
    #     lr=1e-3,
    #     crf=True,
    #     dependencies='tok',
    # ),
    # Dependency parsing (disabled).
    # 'dep': BiaffineDependencyParsing(  # 成分句法
    #     CTB8_SD330_TRAIN,
    #     CTB8_SD330_DEV,
    #     CTB8_SD330_TEST,
    #     SortingSamplerBuilder(batch_size=n_batch_size),
    #     lr=1e-3,
    #     tree=True,
    #     punct=True,
    #     dependencies='tok',
    # ),
    # Semantic dependency parsing (disabled).
    # 'sdp': BiaffineSemanticDependencyParsing(  # 语义依存
    #     SEMEVAL2016_TEXT_TRAIN_CONLLU,
    #     SEMEVAL2016_TEXT_DEV_CONLLU,
    #     SEMEVAL2016_TEXT_TEST_CONLLU,
    #     SortingSamplerBuilder(batch_size=n_batch_size),
    #     lr=1e-3,
    #     apply_constraint=True,
    #     punct=True,
    #     dependencies='tok',
    # ),
    # Constituency parsing (disabled: ran out of memory).
    # 'con': CRFConstituencyParsing(  # 语义角色 memory out
    #     CTB8_BRACKET_LINE_NOEC_TRAIN,
    #     CTB8_BRACKET_LINE_NOEC_DEV,
    #     CTB8_BRACKET_LINE_NOEC_TEST,
    #     SortingSamplerBuilder(batch_size=n_batch_size),
    #     lr=1e-3,
    #     dependencies='tok',
    # )
}

# Train a multi-task model on the tasks defined above, reload it from disk,
# evaluate it, and run one sample sentence through it.
mtl = MultiTaskLearning()
# Relative path; it resolves against the current working directory.
# The directory name still lists all seven tasks although only 'tok' and
# 'ner' are enabled in `tasks`.
save_dir = 'data/model/mtl/open_tok_pos_ner_srl_dep_sdp_con_electra_small'
# save_dir = 'data/model/mtl/open_tok_pos_ner_srl_dep_sdp_con_electra_base'
# save_dir = 'data/model/mtl/open_tok_pos_ner_bert_base'
mtl.fit(
    # Shared encoder: the *small* ELECTRA checkpoint is selected; the BERT-base
    # and ELECTRA-base alternatives are commented out.
    ContextualWordEmbedding('token',
                            # "bert-base-chinese",
                            "hfl/chinese-electra-180g-small-discriminator",
                            # "hfl/chinese-electra-180g-base-discriminator",
                            average_subwords=True,
                            max_sequence_length=510,
                            word_dropout=.1),
    tasks,
    save_dir,
    1,  # presumably the epoch count (the log prints "Epoch 1 / 1"); 30 kept as a note
    lr=1e-3,
    encoder_lr=5e-5,
    grad_norm=1,
    gradient_accumulation=2,
    eval_trn=False,
)
cprint(f'Model saved in [cyan]{save_dir}[/cyan]')
mtl.load(save_dir)
# NOTE(review): `v` is `tasks[k]` here, so each assignment below sets an
# attribute to its own current value. Possibly this was meant to iterate over
# the reloaded model's tasks instead - confirm.
for k, v in tasks.items():
    v.trn = tasks[k].trn
    v.dev = tasks[k].dev
    v.tst = tasks[k].tst
# Only the first element returned by evaluate() is used; it is indexed by task name.
metric, *_ = mtl.evaluate(save_dir)
# Print every task's metric on a single line, separated by spaces.
for k, v in tasks.items():
    print(metric[k], end=' ')
print()
# Run the model on one sample sentence and print the result.
print(mtl('华纳音乐旗下的新垣结衣在12月21日于日本武道馆举办歌手出道活动'))

Describe the current behavior: 运行 `python open_base.py` 报错:RuntimeError: mat1 and mat2 shapes cannot be multiplied (800x256 and 768x1536)

Expected behavior 程序正常训练

System information

Other info / logs

Using GPUs: [0]
Epoch 1 / 1:
1/24919 loss: 0.7001 ETA: 43 m 14 s
Traceback (most recent call last):
  File "hanlp_train.py", line 135, in <module>
    eval_trn=False,
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/hanlp/components/mtl/multi_task_learning.py", line 644, in fit
    tasks)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/hanlp/common/torch_component.py", line 295, in fit
    overwrite=True))
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/hanlp/components/mtl/multi_task_learning.py", line 287, in execute_training_loop
    self.config)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/hanlp/components/mtl/multi_task_learning.py", line 346, in fit_dataloader
    output_dict, _ = self.feed_batch(batch, task_name)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/hanlp/components/mtl/multi_task_learning.py", line 687, in feed_batch
    decoder=self.model.decoders[task_name]),
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/hanlp/components/mtl/tasks/__init__.py", line 182, in feed_batch
    return decoder(h, batch=batch, mask=mask)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/hanlp/components/mtl/tasks/ner/tag_ner.py", line 35, in forward
    contextualized_embeddings = self.secondary_encoder(contextualized_embeddings, mask=mask)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/hanlp/layers/transformers/relative_transformer.py", line 309, in forward
    x = layer(x, mask)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/hanlp/layers/transformers/relative_transformer.py", line 263, in forward
    x = self.self_attn(x, mask)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/hanlp/layers/transformers/relative_transformer.py", line 135, in forward
    qv = self.qv_linear(x) # batch_size x max_len x d_model2
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/xy/miniconda3/envs/py364_xy/lib/python3.6/site-packages/torch/nn/functional.py", line 1848, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (800x256 and 768x1536)

hankcs commented 2 years ago

不是bug,你既然将Electra从base改成了small,就应当将secondary_encoder的维度改成small的256,即:

secondary_encoder=RelativeTransformerEncoder(768, k_as_x=True),
--->
secondary_encoder=RelativeTransformerEncoder(256, k_as_x=True),
Yumeka999 commented 2 years ago

感谢何老师的指正!