yabufarha / ms-tcn


bug #8

Closed PolynomialQian closed 5 years ago

PolynomialQian commented 5 years ago

Hi, I am trying to reproduce your results, but the code fails during training. Is the published code correct? I am running PyTorch 1.0.1 and I don't know whether that is the cause of the error. Could you please take a look? The error is as follows:

    Traceback (most recent call last):
      File "/Users/polypubki/Downloads/adata/main.py", line 72, in <module>
        trainer.train(model_dir, batch_gen, num_epochs=num_epochs, batch_size=bz, learning_rate=lr, device=device)
      File "/Users/polypubki/Downloads/adata/model.py", line 79, in train
        batch_input, batch_target, mask = batch_gen.next_batch(batch_size)
      File "/Users/polypubki/Downloads/adata/batch_gen.py", line 56, in next_batch
        batch_target_tensor = torch.ones(len(batch_input), max(length_of_sequences), dtype=torch.long)*(-100)
    ValueError: max() arg is an empty sequence

yabufarha commented 5 years ago

Hi, the code is tested with PyTorch 0.4.1, but I don't think the PyTorch version is the problem. What's your Python version?

PolynomialQian commented 5 years ago

My PyTorch version is 1.0.

PolynomialQian commented 5 years ago

I can't train the network with the code you gave.

yabufarha commented 5 years ago

Actually, I'm asking about the Python version, not the PyTorch version.

PolynomialQian commented 5 years ago

Sorry, my Python version is 3.7. Is this the problem?

yabufarha commented 5 years ago

Yes, this is the problem. The code works with Python 2.7; if you want to run it with Python 3.7, you have to make some changes. For the error that you have, I think you need to modify line 49 in batch_gen.py.

For Python 2.7:

    length_of_sequences = map(len, batch_target)

For Python 3.7:

    length_of_sequences = list(map(len, batch_target))
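
For readers wondering why wrapping the call in list() helps: in Python 3, map() returns a lazy iterator that can be consumed only once, whereas in Python 2 it returned a list. The sketch below is a minimal illustration, not the actual batch_gen.py code. It assumes that length_of_sequences is passed to max() more than once (only the call on line 56 is visible in the traceback above), and the sample batch_target and the helper max_twice are hypothetical names invented for the example.

```python
# Hypothetical stand-in for batch_target: three label sequences of different lengths.
batch_target = [[0, 1, 1], [2, 2], [0, 0, 0, 1]]


def max_twice(lengths):
    """Call max() twice on the same object and report both results.

    The second result is None if max() raises ValueError, which is what
    happens when `lengths` is an iterator that the first call exhausted.
    """
    first = max(lengths)
    try:
        second = max(lengths)
    except ValueError:
        second = None
    return first, second


# Python 3: map() yields a one-shot iterator, so the second max() sees nothing.
lazy_result = max_twice(map(len, batch_target))

# Materializing the iterator as a list makes it safe to reuse.
list_result = max_twice(list(map(len, batch_target)))

print("map iterator:", lazy_result)
print("list:        ", list_result)
```

Under Python 3, the first call reports a failed second max(), while the list version returns the same maximum both times. This matches the symptom in the traceback, where max() complains about an empty sequence even though the batch itself is not empty.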