wouterkool / attention-learn-to-route

Attention based model for learning to solve different routing problems
MIT License
1.04k stars, 337 forks

Evaluation with the pretrained TSP model raises an AssertionError #34

Open th-yoon opened 3 years ago

th-yoon commented 3 years ago

Command: python3 eval.py data/tsp/tsp20_validation_seed4321.pkl --model pretrained/tsp_20 --decode_strategy greedy

Trace:

  [*] Loading model from pretrained/tsp_20/epoch-99.pt
  0%|                                                                                            | 0/10 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "eval.py", line 216, in <module>
    eval_dataset(dataset_path, width, opts.softmax_temperature, opts)
  File "eval.py", line 70, in eval_dataset
    results = _eval_dataset(model, dataset, width, softmax_temp, opts, device)
  File "eval.py", line 140, in _eval_dataset
    sequences, costs = model.sample_many(batch, batch_rep=batch_rep, iter_rep=iter_rep)
  File "/home/thyoon/attention-learn-to-route/nets/attention_model.py", line 288, in sample_many
    batch_rep, iter_rep
  File "/home/thyoon/attention-learn-to-route/utils/functions.py", line 189, in sample_many
    _log_p, pi = inner_func(input)
  File "/home/thyoon/attention-learn-to-route/nets/attention_model.py", line 285, in <lambda>
    lambda input: self._inner(*input),  # Need to unpack tuple into arguments
  File "/home/thyoon/attention-learn-to-route/nets/attention_model.py", line 234, in _inner
    batch_size = state.ids.size(0)
  File "/home/thyoon/attention-learn-to-route/problems/tsp/state_tsp.py", line 31, in __getitem__
    assert torch.is_tensor(key) or isinstance(key, slice)  # If tensor, idx all tensors by this tensor:
AssertionError

th-yoon commented 3 years ago

I worked around this issue by commenting out the __getitem__ overrides in two classes: AttentionModelFixed (nets/attention_model.py) and StateTSP (problems/tsp/state_tsp.py).

wouterkool commented 3 years ago

Which version of Python are you using? I think this is related to using a Python version below 3.8. Please make sure you use 3.8. Commenting out __getitem__ probably works fine as long as you're not using beam search (which requires indexing into the state array).
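
A likely explanation, stated as an assumption rather than something confirmed in this thread: before Python 3.8, namedtuple fields were properties built on operator.itemgetter, so an attribute access such as state.ids was routed through __getitem__ with an integer key, which the assert in the traceback rejects. From 3.8 on, field access no longer goes through __getitem__. The sketch below shows one way an override could tolerate integer keys on either version. ToyState and its list-valued fields are hypothetical stand-ins for illustration, not the repository's StateTSP, and plain lists are used in place of tensors.

```python
from typing import List, NamedTuple


class ToyState(NamedTuple):
    """Hypothetical stand-in for a NamedTuple-based state (not the real StateTSP)."""

    ids: List[int]
    prev_a: List[int]

    def __getitem__(self, key):
        # Integer keys are positional field lookups. On Python < 3.8, plain
        # attribute access (state.ids) is assumed to arrive here as well, so
        # delegate to the ordinary tuple behaviour instead of asserting.
        if isinstance(key, int):
            return tuple.__getitem__(self, key)
        # Slice keys index every field, the way a batch of states is sliced.
        assert isinstance(key, slice)
        return self._replace(ids=self.ids[key], prev_a=self.prev_a[key])


state = ToyState(ids=[0, 1, 2, 3], prev_a=[9, 8, 7, 6])
first_field = state.ids   # attribute access works on either Python version
sub_state = state[1:3]    # slicing still indexes into every field
```

This keeps slice-based indexing available, which the comment above says beam search relies on, whereas commenting out __getitem__ removes it.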

th-yoon commented 3 years ago

@wouterkool Thanks for your quick reply. My current Python version is 3.7.9. I will try with Python >= 3.8.