I'm trying to run the training script with Python 3.8.10 and torch==1.10.2+cu113, and I get the following error:
>> bash thualign/bin/train.sh -s mask_align -e agree_deen
running mask_align
Traceback (most recent call last):
  File "/net/aistaff/sarti/Mask-Align/thualign/bin/trainer.py", line 21, in <module>
    import thualign.data as data
  File "/net/aistaff/sarti/Mask-Align/thualign/data/__init__.py", line 5, in <module>
    from thualign.data.dataset import Dataset, TextLineDataset
  File "/net/aistaff/sarti/Mask-Align/thualign/data/dataset.py", line 51, in <module>
    class Dataset(IterableDataset):
  File "/net/aistaff/sarti/Mask-Align/venv/lib/python3.8/site-packages/torch/utils/data/_typing.py", line 273, in __new__
    return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]
  File "/usr/lib/python3.8/abc.py", line 85, in __new__
    cls = super().__new__(mcls, name, bases, namespace, **kwargs)
  File "/net/aistaff/sarti/Mask-Align/venv/lib/python3.8/site-packages/torch/utils/data/_typing.py", line 373, in _dp_init_subclass
    raise TypeError("Expected 'Iterator' as the return annotation for `__iter__` of {}"
TypeError: Expected 'Iterator' as the return annotation for `__iter__` of Dataset, but found thualign.data.iterator.Iterator
Is there a specific pinned version of torch that the script is known to work with?