AI4Bharat / IndicTrans2

Translation models for 22 scheduled languages of India
https://ai4bharat.iitm.ac.in/indic-trans2
MIT License

Distillation: Unable to start the training #75

Closed. harshyadav17 closed this issue 3 months ago.

harshyadav17 commented 3 months ago

Hey @PranjalChitale @VarunGumma,

The following is the command to start the training: `bash distill.sh <data_dir> <teacher_ckpt_path>`

I am unable to start the training with this command. I am using a fresh clone of the latest version of the branch, and it fails with the following traceback:


```
Traceback (most recent call last):
  File "/opt/conda/envs/itdv2/lib/python3.11/site-packages/torch/multiprocessing/spawn.py", line 75, in _wrap
    fn(i, *args)
  File "/opt/conda/envs/itdv2/lib/python3.11/site-packages/fairseq/distributed/utils.py", line 362, in distributed_main
    main(cfg, **kwargs)
  File "/opt/conda/envs/itdv2/lib/python3.11/site-packages/fairseq_cli/train.py", line 208, in main
    criterion = task.build_criterion(cfg.criterion)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/itdv2/lib/python3.11/site-packages/fairseq/tasks/fairseq_task.py", line 356, in build_criterion
    return criterions.build_criterion(cfg, self, from_checkpoint=from_checkpoint)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/itdv2/lib/python3.11/site-packages/fairseq/criterions/__init__.py", line 29, in build_criterion
    return build_criterion_(cfg, task, from_checkpoint=from_checkpoint)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/itdv2/lib/python3.11/site-packages/fairseq/registry.py", line 65, in build_x
    return builder(cfg, *extra_args, **extra_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/itdv2/lib/python3.11/site-packages/fairseq/criterions/fairseq_criterion.py", line 61, in build_criterion
    return cls(**init_args)
           ^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/itdv2/lib/python3.11/site-packages/fairseq/criterions/label_smoothed_cross_entropy_with_kd.py", line 62, in __init__
    self.kd_rate = kd_args.get("rate", 0.0)
                   ^^^^^^^^^^^
AttributeError: 'str' object has no attribute 'get'
```
harshyadav17 commented 3 months ago

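As a temporary workaround, I added `kd_args = json.loads(kd_args)` at line 49 of `IndicTrans2/fairseq/fairseq/criterions/label_smoothed_cross_entropy_with_kd.py`, so that the string is parsed into a dict before `kd_args.get("rate", 0.0)` is called.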
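A minimal sketch of the kind of normalization this workaround performs, assuming `kd_args` can reach the criterion either as a JSON string (for example, passed on the command line) or as an already-parsed dict. The helper name `parse_kd_args` is hypothetical and is not the fix that was merged into fairseq; only the `"rate"` key and its `0.0` default come from the traceback above.

```python
import json
from typing import Any, Dict, Optional, Union


def parse_kd_args(kd_args: Optional[Union[str, Dict[str, Any]]]) -> Dict[str, Any]:
    """Normalize knowledge-distillation options into a dict.

    Hypothetical helper (not part of fairseq): accepts a JSON string,
    an already-parsed dict, or None / an empty value.
    """
    # None, "" or {} all mean "no KD options given".
    if not kd_args:
        return {}
    if isinstance(kd_args, str):
        # e.g. '{"rate": 0.5}' -> {"rate": 0.5}
        parsed = json.loads(kd_args)
        if not isinstance(parsed, dict):
            raise ValueError(f"kd_args must be a JSON object, got: {kd_args!r}")
        return parsed
    if isinstance(kd_args, dict):
        # Already parsed upstream; use it unchanged.
        return kd_args
    raise TypeError(f"Unsupported kd_args type: {type(kd_args).__name__}")


# The failing call `kd_args.get("rate", 0.0)` now works for every input form.
kd_rate_from_str = parse_kd_args('{"rate": 0.5}').get("rate", 0.0)
kd_rate_from_dict = parse_kd_args({"rate": 0.5}).get("rate", 0.0)
kd_rate_default = parse_kd_args(None).get("rate", 0.0)
print(kd_rate_from_str, kd_rate_from_dict, kd_rate_default)  # 0.5 0.5 0.0
```

Checking the type before parsing keeps the criterion working even if the value ever arrives already parsed: a bare `json.loads(kd_args)` on a dict would itself raise a `TypeError`.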

VarunGumma commented 3 months ago

@harshyadav17, yes, that was a small bug in the library. I have fixed it in fairseq, so you can pull the latest changes. Thanks.