yuqianghan / editretro

Retrosynthesis Prediction with an Iterative String Editing Model
MIT License

Error when running generate #11

Open · mspythontu opened this issue 2 months ago

mspythontu commented 2 months ago

While running generation through iterative_refinement_generator.py, the call word_ins_score, word_ins_pred = word_ins_score.topk(token_beam, dim=-1) (in editretro_nat.py) fails with:

        sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-generate')())
      File "/home/top/projects/editretro/fairseq/fairseq_cli/generate.py", line 207, in cli_main
        main(args)
      File "/home/top/projects/editretro/fairseq/fairseq_cli/generate.py", line 105, in main
        hypos = task.inference_step(generator, models, sample, prefix_tokens)
      File "/home/top/projects/editretro/fairseq/fairseq/tasks/fairseq_task.py", line 351, in inference_step
        return generator.generate(models, sample, prefix_tokens=prefix_tokens)
      File "/home/top/anaconda3/envs/editretro/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "/home/top/projects/editretro/editretro/models/iterative_refinement_generator.py", line 305, in generate
        decoder_out = model.forward_decoder_token(
      File "/home/top/projects/editretro/editretro/models/editretro_nat.py", line 732, in forward_decoder_token
        word_ins_score, word_ins_pred = word_ins_score.topk(token_beam, dim=-1)
    RuntimeError: CUDA error: an illegal memory access was encountered
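CUDA kernels run asynchronously, so an illegal memory access is often reported at a later API call: the topk line in this traceback is not necessarily the kernel that actually faulted. A common first step is to rerun with the environment variable CUDA_LAUNCH_BLOCKING=1, which makes every launch synchronous so the traceback points at the offending operation (generation becomes slower, so use it only for debugging). A minimal Python sketch; the helper names are hypothetical, and the command you pass should be whatever you originally ran:

```python
import os
import subprocess
from typing import Dict, Mapping, Optional, Sequence


def blocking_cuda_env(base: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return a copy of an environment with synchronous CUDA kernel launches.

    With CUDA_LAUNCH_BLOCKING=1 each launch waits for the kernel to finish, so
    an illegal memory access is raised by the kernel that caused it instead of
    by a later, unrelated call. The input mapping is never modified.
    """
    env = dict(os.environ if base is None else base)
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    return env


def rerun_blocking(cmd: Sequence[str]) -> int:
    """Re-run a command (e.g. the original fairseq-generate call) with blocking launches."""
    return subprocess.run(list(cmd), env=blocking_cuda_env(), check=False).returncode
```

For example, rerun_blocking(["bash", "1_generate_50k.sh"]) is a hypothetical invocation assuming the generation script from the reply below sits in the current directory.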

The value of word_ins_score is:
tensor([[[ -8.4584,  -8.1290,  -7.4481,  ...,  -8.1291,  -8.1291,  -8.1291],
     [-10.1991, -10.8062, -11.1370,  ..., -10.8064, -10.8062, -10.8063],
     [ -9.7229, -10.7866, -10.7560,  ..., -10.7868, -10.7867, -10.7868],
     ...,
     [ -9.8793, -10.6177, -10.9736,  ..., -10.6178, -10.6177, -10.6178],
     [ -9.8793, -10.6177, -10.9736,  ..., -10.6178, -10.6177, -10.6178],
     [ -9.8793, -10.6177, -10.9736,  ..., -10.6178, -10.6177, -10.6178]],

    [[ -6.7876,  -8.2179,  -6.8007,  ...,  -8.2180,  -8.2179,  -8.2180],
     [ -9.6962, -11.2036, -11.5285,  ..., -11.2039, -11.2037, -11.2038],
     [ -9.1220, -10.7852, -10.6264,  ..., -10.7854, -10.7853, -10.7854],
     ...,
     [ -9.9026, -11.0796, -11.4898,  ..., -11.0799, -11.0796, -11.0798],
     [ -9.9026, -11.0796, -11.4898,  ..., -11.0799, -11.0796, -11.0798],
     [ -9.9026, -11.0796, -11.4898,  ..., -11.0799, -11.0796, -11.0798]],

    [[ -8.1274,  -8.1378,  -6.5826,  ...,  -8.1379,  -8.1379,  -8.1379],
     [ -9.9285, -11.2284, -11.5292,  ..., -11.2286, -11.2284, -11.2285],
     [ -9.8055, -10.8103, -10.7344,  ..., -10.8105, -10.8105, -10.8105],
     ...,
     [ -9.8142, -10.6868, -10.9587,  ..., -10.6869, -10.6868, -10.6869],
     [ -9.8142, -10.6868, -10.9587,  ..., -10.6869, -10.6868, -10.6869],
     [ -9.8142, -10.6868, -10.9587,  ..., -10.6869, -10.6868, -10.6869]],

    ...,

    [[ -8.6219,  -8.2528,  -6.6367,  ...,  -8.2529,  -8.2529,  -8.2528],
     [-10.0334, -10.8612, -11.3149,  ..., -10.8614, -10.8613, -10.8613],
     [ -9.6759, -10.8410, -10.8632,  ..., -10.8412, -10.8412, -10.8412],
     ...,
     [-10.3744,  -9.4961,  -9.0727,  ...,  -9.4963,  -9.4962,  -9.4962],
     [-10.3744,  -9.4961,  -9.0727,  ...,  -9.4963,  -9.4962,  -9.4962],
     [-10.3744,  -9.4961,  -9.0727,  ...,  -9.4963,  -9.4962,  -9.4962]],

    [[ -8.9653,  -8.1965,  -6.6784,  ...,  -8.1966,  -8.1966,  -8.1966],
     [-10.2569, -10.8919, -11.2553,  ..., -10.8920, -10.8920, -10.8919],
     [ -9.5705, -10.2760, -10.5224,  ..., -10.2762, -10.2762, -10.2762],
     ...,
     [-11.1796, -10.5633, -10.1849,  ..., -10.5634, -10.5634, -10.5634],
     [-11.1796, -10.5633, -10.1849,  ..., -10.5634, -10.5634, -10.5634],
     [-11.1796, -10.5633, -10.1849,  ..., -10.5634, -10.5634, -10.5634]],

    [[ -8.2500,  -8.2720,  -1.9953,  ...,  -8.2719,  -8.2720,  -8.2718],
     [ -8.6661, -10.7750, -12.0660,  ..., -10.7750, -10.7751, -10.7749],
     [ -8.1770,  -9.1500,  -9.6314,  ...,  -9.1500,  -9.1500,  -9.1500],
     ...,
     [ -9.7259, -11.0957, -10.8102,  ..., -11.0956, -11.0956, -11.0956],
     [ -9.7259, -11.0957, -10.8102,  ..., -11.0956, -11.0956, -11.0956],
     [ -9.7259, -11.0957, -10.8102,  ..., -11.0956, -11.0956, -11.0956]]],
   device='cuda:0')
Does generate have a minimum GPU memory requirement?
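The failing line keeps, for every output position, the token_beam highest-scoring vocabulary entries. The dump above appears to be a (batch, positions, vocabulary) tensor of log-probabilities; its repr shows no dtype, which in PyTorch means the default float32. The sketch below is a hedged, CPU-only reference for what topk(k, dim=-1) computes on one 2-D slice; the function name is hypothetical. It raises if k exceeds the row size, whereas PyTorch validates k on the host and reports "selected index k out of range" in that case, so a bad k alone would not normally show up as an illegal memory access.

```python
import heapq
from typing import List, Sequence, Tuple


def topk_last_dim(
    scores: Sequence[Sequence[float]], k: int
) -> Tuple[List[List[float]], List[List[int]]]:
    """Reference top-k over the last dimension of a 2-D list of scores.

    For each row (one output position) return the k largest scores in
    descending order, together with their column (vocabulary) indices.
    Ties keep their original column order in this sketch.
    """
    values: List[List[float]] = []
    indices: List[List[int]] = []
    for row in scores:
        if not 0 <= k <= len(row):
            raise ValueError(f"k={k} is out of range for a row of size {len(row)}")
        # Pick the k column indices with the highest scores.
        best = heapq.nlargest(k, range(len(row)), key=row.__getitem__)
        indices.append(best)
        values.append([row[i] for i in best])
    return values, indices
```

Applied to only the three leading values shown for the first position, [-8.4584, -8.1290, -7.4481] with k=2, it returns indices [2, 1] and values [-7.4481, -8.1290].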
yuqianghan commented 1 month ago
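  1. Hi. First, please check whether fairseq was installed successfully; in particular, confirm the CUDA (11.6.0) and gcc (9.4.0) versions. A correct installation produces the file fairseq/fairseq/libnat_cuda.cpython-310-x86_64-linux-gnu.so.
  2. generate does have GPU memory requirements. Please refer to the latest 1_generate_50k.sh script and adjust the max_tokens parameter.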
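The first check can be scripted. A hedged sketch: it looks for the compiled libnat_cuda extension in fairseq's package directory (fairseq/fairseq relative to the editretro checkout, matching the traceback paths) and reports the gcc version and the CUDA version PyTorch was built with, for comparison against 9.4.0 and 11.6.0. The helper names are hypothetical. Note that torch.version.cuda reports PyTorch's build CUDA, which may differ from the nvcc used to compile fairseq's extensions, and gcc -dumpfullversion assumes gcc 7 or newer.

```python
import shutil
import subprocess
from pathlib import Path
from typing import List, Optional


def find_libnat_cuda(fairseq_pkg_dir: Path) -> List[Path]:
    """Return compiled libnat_cuda extension files in fairseq's package directory."""
    return sorted(Path(fairseq_pkg_dir).glob("libnat_cuda*.so"))


def gcc_version() -> Optional[str]:
    """Return the full gcc version (e.g. '9.4.0'), or None if gcc is not on PATH."""
    if shutil.which("gcc") is None:
        return None
    out = subprocess.run(
        ["gcc", "-dumpfullversion"], capture_output=True, text=True, check=False
    )
    return out.stdout.strip() or None


def torch_cuda_version() -> Optional[str]:
    """Return the CUDA version PyTorch was built with, or None (no torch / CPU build)."""
    try:
        import torch
    except ImportError:
        return None
    return torch.version.cuda


if __name__ == "__main__":
    found = find_libnat_cuda(Path("fairseq/fairseq"))
    print("libnat_cuda:", [p.name for p in found] or "NOT FOUND")
    print("gcc:", gcc_version())
    print("torch CUDA:", torch_cuda_version())
```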
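On the second point: fairseq's --max-tokens option caps each batch by token count, and because of padding a batch costs roughly (number of samples) × (longest sample length). Score tensors such as word_ins_score grow with batch × length × vocabulary, so a smaller budget lowers peak GPU memory. Below is a sketch of greedy, padding-aware batching under that assumption, plus a size estimate for a float32 (batch, length, vocab) tensor. It is not fairseq's implementation (this version simply raises on an oversized sample), the function names are hypothetical, and how 1_generate_50k.sh forwards max_tokens is not shown here.

```python
from typing import List, Sequence


def batch_by_max_tokens(lengths: Sequence[int], max_tokens: int) -> List[List[int]]:
    """Greedily group sample indices so that len(batch) * longest <= max_tokens.

    Samples are visited shortest first, which keeps padding low.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches: List[List[int]] = []
    batch: List[int] = []
    longest = 0
    for idx in order:
        n = lengths[idx]
        if n > max_tokens:
            raise ValueError(f"sample {idx} has {n} tokens, more than max_tokens={max_tokens}")
        new_longest = max(longest, n)
        # Adding this sample would exceed the padded budget: close the batch.
        if batch and (len(batch) + 1) * new_longest > max_tokens:
            batches.append(batch)
            batch, new_longest = [], n
        batch.append(idx)
        longest = new_longest
    if batch:
        batches.append(batch)
    return batches


def score_tensor_bytes(batch: int, length: int, vocab: int, bytes_per_elem: int = 4) -> int:
    """Memory of a (batch, length, vocab) score tensor; 4 bytes per float32 element."""
    return batch * length * vocab * bytes_per_elem
```

For example, lengths [5, 3, 8, 2] with max_tokens=10 give [[3, 1], [0], [2]], while max_tokens=20 gives [[3, 1, 0], [2]]: a larger budget packs more samples, and therefore larger score tensors, into each step.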