danliu2 / caat

MIT License

Offline generation of CAAT model #11

Closed sarapapi closed 2 years ago

sarapapi commented 2 years ago

Hi, I am trying to investigate the performance drop that I wrote about in issue #10, and I tried to run offline generation with my CAAT model in order to compare its results with those of your paper. However, when I launch the classical Fairseq generate, I receive this message:

transducer temperature= 1.0
encoder. initialsed from /storage/MT/sara/simultaneous/training/ST/MustC/en-es/caat_asrpretrain/checkpoint_best.pt complete
Traceback (most recent call last):
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq_cli/generate.py", line 411, in <module>
    cli_main()
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq_cli/generate.py", line 407, in cli_main
    main(args)
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq_cli/generate.py", line 47, in main
    return _main(cfg, h)
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq_cli/generate.py", line 170, in _main
    generator = task.build_generator(
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq/../rain/tasks/transducer_task.py", line 35, in build_generator
    raise NotImplementedError("TODO")
NotImplementedError: TODO

What did you use to compute your model's results offline? Thanks again

danliu2 commented 2 years ago

I used s2s_task for offline inference. Sorry for the missing code.

sarapapi commented 2 years ago

Hi, thanks for your reply. I switched to the s2s task, but the following error arises:

Traceback (most recent call last):
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq_cli/generate.py", line 411, in <module>
    cli_main()
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq_cli/generate.py", line 407, in cli_main
    main(args)
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq_cli/generate.py", line 49, in main
    return _main(cfg, sys.stdout)
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq_cli/generate.py", line 204, in _main
    hypos = task.inference_step(
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq/../rain/tasks/s2s_task.py", line 409, in inference_step
    return generator.generate(
  File "/home/spapi/anaconda3/envs/caat_env/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq/sequence_generator.py", line 182, in generate
    return self._generate(sample, **kwargs)
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq/../rain/sequence_generator2.py", line 189, in _generate
    if self.no_repeat_ngram_size > 0:
  File "/home/spapi/anaconda3/envs/caat_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 778, in __getattr__
    raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
torch.nn.modules.module.ModuleAttributeError: 'SequenceGenerator2' object has no attribute 'no_repeat_ngram_size'

danliu2 commented 2 years ago

Fixed in e1cae5f. Thanks for finding this bug. It was caused by a newer Fairseq version: they removed the member no_repeat_ngram_size from the init, and it is not used in my inference.
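For anyone who hits this before pulling the fix, a minimal workaround sketch (my own illustration, not the actual patch in e1cae5f; the helper name is hypothetical) is to guard the attribute read, since newer Fairseq no longer sets no_repeat_ngram_size in the generator's init and reading a missing attribute on an nn.Module subclass raises:

```python
def get_no_repeat_ngram_size(generator):
    # Newer Fairseq removed this attribute from SequenceGenerator.__init__;
    # fall back to 0 (feature disabled) when it is absent, instead of
    # letting nn.Module.__getattr__ raise a ModuleAttributeError.
    return getattr(generator, "no_repeat_ngram_size", 0)
```

The check in rain/sequence_generator2.py then becomes `if get_no_repeat_ngram_size(self) > 0:` and behaves identically on both old and new Fairseq versions.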

sarapapi commented 2 years ago

Hi, thanks for your reply. I applied your patch but then another error arises:

2022-02-02 17:01:23 | INFO | fairseq_cli.generate | loading model(s) from /storage/MT/sara/simultaneous/training/ST/MustC/en-es/caat/checkpoint_best.pt
transducer temperature= 1.0
encoder. initialsed from /storage/MT/sara/simultaneous/training/ST/MustC/en-es/caat_asrpretrain/checkpoint_best.pt complete
2022-02-02 17:01:28 | WARNING | fairseq.tasks.fairseq_task | 41 samples have invalid sizes and will be skipped, max_positions=(2000, 512), first few sample ids=[728, 448, 835, 1071, 509, 242, 281, 375, 396, 805]
Traceback (most recent call last):
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq_cli/generate.py", line 411, in <module>
    cli_main()
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq_cli/generate.py", line 407, in cli_main
    main(args)
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq_cli/generate.py", line 49, in main
    return _main(cfg, sys.stdout)
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq_cli/generate.py", line 241, in _main
    src_str = src_dict.string(src_tokens, cfg.common_eval.post_process)
  File "/home/spapi/fairseq_simul/simul_CAAT/fairseq/fairseq/data/dictionary.py", line 99, in string
    sent = " ".join(
TypeError: 'NoneType' object is not iterable

Which code did you use to run the offline generation? Thanks again

danliu2 commented 2 years ago

This may be caused by a few small changes to Fairseq. In the version I used for my experiments, that part reads:

    if src_dict is not None and src_tokens is not None:
        src_str = src_dict...
    else:
        src_str = ""

For speech translation there are no src_tokens here, and this code is only for result logging, so you can bypass it with a small modification.
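The guard described above can be sketched as a small helper (the function name is my own; it is not part of Fairseq) that wraps the logging call in fairseq_cli/generate.py:

```python
def format_source(src_dict, src_tokens, post_process=None):
    # For speech translation there is no source dictionary or source
    # token tensor, so skip dictionary decoding and log an empty string
    # instead of calling src_dict.string(None) and crashing.
    if src_dict is not None and src_tokens is not None:
        return src_dict.string(src_tokens, post_process)
    return ""
```

Since src_str is only used for printing the source line alongside each hypothesis, returning "" simply leaves that field blank in the output without affecting the generated translations.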