danliu2 / caat

MIT License
34 stars 2 forks source link

text-to-text simultaneous translation #7

Closed RaynerWu closed 2 years ago

RaynerWu commented 2 years ago

hi, 可以提供一个text-to-text simultaneous translation的例子吗?感谢!

danliu2 commented 2 years ago

text-to-text模型训练直接使用fairseq-preprocess得到的数据集,将arch指定为mt_mha(rain/model/transducer.py)即可,需注意因为joiner显存消耗过大,batch-size需要设置比常规模型小2倍左右。推理时使用simuleval工具启动agent:caat/rain/simul/text_fullytransducer_agent.py,参数和speech-to-text类似。 非常抱歉目前开源的实验代码组织过于粗糙,我过些时间会提供更详细的说明和示例代码,并重构原代码中一些晦涩的地方。

RaynerWu commented 2 years ago

感谢回复

danliu2 commented 2 years ago

你好,gpu_rnnt_delay是C++代码实现函数。你需要安装我目录里提供的warp_rnnt版本:先cmake &make install,然后在pytorch_binding目录下pip install。需要注意避免和常规版warp_rnnt冲突。 非常抱歉此前代码整理过于粗糙,稍后我会在时间允许情况下重构这部分。

发件人: havefun @. 发送时间: 2022年1月17日 15:09 收件人: danliu2/caat @.> 抄送: danliu2 @.>; Comment @.> 主题: Re: [danliu2/caat] text-to-text simultaneous translation (Issue #7)

你好,运行时遇到了一个错误 File "/caat/warp_transducer/pytorch_binding/warprnnt_pytorch/delay_transducer.py", line 56, in forward loss_func = warp_rnnt.gpu_rnnt_delay AttributeError: module 'warprnnt_pytorch' has no attribute 'gpu_rnnt_delay'

— Reply to this email directly, view it on GitHubhttps://github.com/danliu2/caat/issues/7#issuecomment-1014203816, or unsubscribehttps://github.com/notifications/unsubscribe-auth/AA3ZUTTAOISBCLXATEZRCGTUWO6ANANCNFSM5LVFSEOQ. Triage notifications on the go with GitHub Mobile for iOShttps://apps.apple.com/app/apple-store/id1477376905?ct=notification-email&mt=8&pt=524675 or Androidhttps://play.google.com/store/apps/details?id=com.github.android&referrer=utm_campaign%3Dnotification-email%26utm_medium%3Demail%26utm_source%3Dgithub. You are receiving this because you commented.Message ID: @.**@.>>

dearchill commented 2 years ago

text-to-text模型训练直接使用fairseq-preprocess得到的数据集,将arch指定为mt_mha(rain/model/transducer.py)即可,需注意因为joiner显存消耗过大,batch-size需要设置比常规模型小2倍左右。推理时使用simuleval工具启动agent:caat/rain/simul/text_fullytransducer_agent.py,参数和speech-to-text类似。 非常抱歉目前开源的实验代码组织过于粗糙,我过些时间会提供更详细的说明和示例代码,并重构原代码中一些晦涩的地方。

@danliu2 hi我想继续请教一下text-to-text的训练参数,我目前参照您scripts下面train_transducer.sh的参数设置,然后arch换成mt_mha,但是遇到了一些问题,首先是训练过程中loss一直是nan和-inf,然后训练了不到2万步,用fairseq-generate生成译文出现了明显的漏译和鬼畜现象,我的参数如下:

fairseq-train $data/data-bin --source-lang en --target-lang zh \ --user-dir rain \ --ddp-backend=no_c10d \ --task transducer --task-type mt \ --bpe-dropout 0.1 \ --arch mt_mha --dropout 0.3 --activation-dropout 0.1 \ --share-decoder-input-output-embed \ --optimizer adam --adam-betas '(0.9, 0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ --warmup-updates 4000 --warmup-init-lr '1e-07' \ --criterion fake_loss \ --weight-decay 0.0001 \ --save-dir $data/checkpoints \ --tensorboard-logdir $data/tensorboard \ --clip-norm 2 \ --max-tokens 4096 \ --update-freq 8 \ --keep-interval-updates 15 --save-interval-updates 3000 --log-interval 100 \ --no-progress-bar

下面是部分翻译结果:

S-1501 ▁just ▁so ▁so T-1501 ▁一般 H-1501 -2.3230488300323486 ▁那么 ▁。 D-1501 -2.3230488300323486 ▁那么 ▁。 P-1501 -3.0586 -3.7270 -0.1835 S-741 ▁hello ▁, ▁everyone T-741 ▁大家 ▁好 H-741 -2.028507709503174 ▁ , ▁每个 ▁人 ▁, ▁每个 ▁人 ▁, ▁你们 ▁, ▁每个 ▁人 ▁, ▁每个 ▁人 ▁ ! D-741 -2.028507709503174 ▁ , ▁每个 ▁人 ▁, ▁每个 ▁人 ▁, ▁你们 ▁, ▁每个 ▁人 ▁, ▁每个 ▁人 ▁ ! P-741 -3.3551 -1.8159 -2.1910 -1.0341 -1.6109 -3.1892 -0.4419 -1.7155 -2.6989 -2.1088 -3.3423 -0.4281 -1.6798 -3.5302 -0.4099 -2.9969 -2.3451 -1.6195 1%|█▉ | 22/1521 [00:02<01:35, 15.66it/s, wps=42]S-804 ▁is ▁it ▁far T-804 ▁远 ▁吗 H-804 -2.832214117050171 ▁? D-804 -2.832214117050171 ▁? P-804 -4.1542 -1.5103 S-899 ▁where ▁is ▁it T-899 ▁在 ▁哪里 H-899 -2.2504234313964844 ▁? D-899 -2.2504234313964844 ▁? P-899 -3.6326 -0.8682 S-954 ▁return ▁to ▁zero T-954 ▁归 零 H-954 -1.6095635890960693 ▁返回 ▁ 零 件 D-954 -1.6095635890960693 ▁返回 ▁ 零 件 P-954 -2.1104 -3.0526 -0.5272 -1.1991 -1.1585 2%|██▎ | 25/1521 [00:02<01:23, 17.81it/s, wps=44]S-1283 ▁follow ▁the ▁rules T-1283 ▁遵守 ▁规则 H-1283 -1.873369812965393 ▁遵循 ▁规则 D-1283 -1.873369812965393 ▁遵循 ▁规则 P-1283 -4.0500 -0.2057 -1.3644 S-1293 ▁here ▁it ▁is T-1293 ▁在 ▁这儿 H-1293 -2.5336668491363525 ▁ 这 ▁是 D-1293 -2.5336668491363525 ▁ 这 ▁是 P-1293 -4.2072 -2.1468 -2.2899 -1.4908 S-1296 ▁how ▁about ▁you T-1296 ▁你 ▁呢 H-1296 -2.6396994590759277 ▁) ▁。 D-1296 -2.6396994590759277 ▁) ▁。 P-1296 -4.3705 -2.3755 -1.1730

有几个疑惑的地方想要确认一下,一个是task和task-type参数,另一个是arch和criterion参数,还有就是训练时需要读取一个text_cfg的yaml文件,我看到是在text_encoder.py里面用到的,如果我的数据都是先tokenize然后bpe,再经过fairseq-preprocess处理过的,是不是就不需要经过这个了。期待您的回复。

danliu2 commented 2 years ago

@dearchill 抱歉刚看到这个问题,您在一个closed issue里的回复我这边被忽略了。问题是否已经解决?如果仍有问题希望以下回复有帮助:

  1. task 中trasnducer task只是为了适配caat loss写的一个临时类,因为CAAT 训练时forward计算和loss耦合在一起;task-type包括asr,mt和st,用来确定读取语音翻译数据时哪些作为src和tgt。
  2. 文本翻译并且不做bpe-dropout时text_cfg应该用不到(但前面代码变量检查时有可能会检查对应文件是否存在)
  3. loss 为NAN的问题:有可能在初期某些步骤计算出了inf的loss(CAAT使用的前后向算法对某些特殊样本,比如token为0可能计算数字异常),而打印结果为moving average 的loss,出现一次inf后所有都是nan,你可以检查一下你的数据,另外在logging时可以用isnan过滤一下。
  4. CAAT模型解码机制上有点特别,需要指定解码bos_token是(0,fairseq默认为eos,2)。您解码异常很可能是这个原因。 想到的大概是这些。如果您有其它问题可以开新的issue讨论。或者直接邮件liudanche@hotmail.com,期待和您更多讨论。