ai4co / rl4co

A PyTorch library for all things Reinforcement Learning (RL) for Combinatorial Optimization (CO)
https://rl4.co
MIT License

Convert TSP to ATSP #108

Closed: Mu-Yanchen closed this issue 5 months ago

Mu-Yanchen commented 11 months ago

Describe the bug

I used the notebook linked below to learn about rl4co: https://github.com/ai4co/rl4co/blob/main/notebooks/tutorials/2-creating-new-env-model.ipynb. I now want to try the ATSP environment, so I import ATSPEnv instead of TSPEnv like this:

from rl4co.envs import ATSPEnv

# rollout and random_policy are the helper functions defined earlier in the tutorial notebook
batch_size = 2
env_atsp = ATSPEnv(num_loc=30)
reward, td, actions = rollout(env_atsp, env_atsp.reset(batch_size=[batch_size]), random_policy)
env_atsp.render(td, actions)

This runs correctly, but when I do a greedy rollout with the untrained model as below, I encounter the following error:

Greedy rollouts over the untrained model

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td_init_atsp = env_atsp.reset(batch_size=[3]).to(device)
# model_atsp: the untrained model built earlier, following the notebook
model_atsp = model_atsp.to(device)
out_atsp = model_atsp(td_init_atsp.clone(), phase="test", decode_type="greedy", return_actions=True)
actions_untrained = out_atsp['actions'].cpu().detach()
rewards_untrained = out_atsp['reward'].cpu().detach()

for i in range(3):
    print(f"Problem {i+1} | Cost: {-rewards_untrained[i]:.3f}")
    env_atsp.render(td_init_atsp[i], actions_untrained[i])

The error is:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[5], line 5
      3 td_init_atsp = env_atsp.reset(batch_size=[3]).to(device)
      4 model_atsp = model_atsp.to(device)
----> 5 out_atsp = model_atsp(td_init_atsp.clone(), phase="test", decode_type="greedy", return_actions=True)
      6 actions_untrained = out_atsp['actions'].cpu().detach()
      7 rewards_untrained = out_atsp['reward'].cpu().detach()

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/miniconda3/lib/python3.10/site-packages/rl4co/models/rl/common/base.py:246, in RL4COLitModule.forward(self, td, **kwargs)
    244     log.info("Using env from kwargs")
    245     env = kwargs.pop("env")
--> 246 return self.policy(td, env, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/miniconda3/lib/python3.10/site-packages/rl4co/models/zoo/common/autoregressive/policy.py:140, in AutoregressivePolicy.forward(self, td, env, phase, return_actions, return_entropy, return_init_embeds, **decoder_kwargs)
    125 """Forward pass of the policy.
    126 
    127 Args:
   (...)
    136     out: Dictionary containing the reward, log likelihood, and optionally the actions and entropy
    137 """
    139 # ENCODER: get embeddings from initial state
--> 140 embeddings, init_embeds = self.encoder(td)
    142 # Instantiate environment if needed
    143 if isinstance(env, str) or env is None:

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/miniconda3/lib/python3.10/site-packages/rl4co/models/zoo/common/autoregressive/encoder.py:74, in GraphAttentionEncoder.forward(self, td, mask)
     62 """Forward pass of the encoder.
     63 Transform the input TensorDict into a latent representation.
     64 
   (...)
     71     init_h: Initial embedding of the input
     72 """
     73 # Transfer to embedding space
---> 74 init_h = self.init_embedding(td)
     76 # Process embedding
     77 h = self.net(init_h, mask)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~/miniconda3/lib/python3.10/site-packages/rl4co/models/nn/env_embeddings/init.py:49, in TSPInitEmbedding.forward(self, td)
     48 def forward(self, td):
---> 49     out = self.init_embed(td["locs"])
     50     return out

File ~/miniconda3/lib/python3.10/site-packages/tensordict/tensordict.py:3697, in TensorDictBase.__getitem__(self, index)
   3695     idx_unravel = _unravel_key_to_tuple(index)
   3696     if idx_unravel:
-> 3697         return self._get_tuple(idx_unravel, NO_DEFAULT)
   3698 if (istuple and not index) or (not istuple and index is Ellipsis):
   3699     # empty tuple returns self
   3700     return self

File ~/miniconda3/lib/python3.10/site-packages/tensordict/tensordict.py:4625, in TensorDict._get_tuple(self, key, default)
   4624 def _get_tuple(self, key, default):
-> 4625     first = self._get_str(key[0], default)
   4626     if len(key) == 1 or first is default:
   4627         return first

File ~/miniconda3/lib/python3.10/site-packages/tensordict/tensordict.py:4621, in TensorDict._get_str(self, key, default)
   4619 out = self._tensordict.get(first_key, None)
   4620 if out is None:
-> 4621     return self._default_get(first_key, default)
   4622 return out

File ~/miniconda3/lib/python3.10/site-packages/tensordict/tensordict.py:1455, in TensorDictBase._default_get(self, key, default)
   1452     return default
   1453 else:
   1454     # raise KeyError
-> 1455     raise KeyError(
   1456         TensorDictBase.KEY_ERROR.format(
   1457             key, self.__class__.__name__, sorted(self.keys())
   1458         )
   1459     )

KeyError: 'key "locs" not found in TensorDict with keys [\'action_mask\', \'cost_matrix\', \'current_node\', \'done\', \'first_node\', \'i\', \'terminated\']'

Reason and Possible fixes

I think the problem is a mismatch between the model and ATSPEnv, but I have not found a solution. Thank you for your time and attention.
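
For reference, resetting the environment and listing the TensorDict keys confirms that ATSP provides cost_matrix rather than locs, which is exactly what the KeyError above complains about:

# sanity check: these are the keys the ATSP environment actually provides
td_check = env_atsp.reset(batch_size=[2])
print(sorted(td_check.keys()))
# ['action_mask', 'cost_matrix', 'current_node', 'done', 'first_node', 'i', 'terminated']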

Mu-Yanchen commented 11 months ago

By the way, how should I train an ATSP model the same way as a TSP model?

Haimrich commented 10 months ago

I think the problem comes from https://github.com/ai4co/rl4co/blob/1a2da37d6104c33646f74bb4b040d2a4006876c2/rl4co/models/nn/env_embeddings/init.py#L16. By default, the same InitEmbedding used for TSP is also used for the ATSP environment. In TSP you can simply embed the coordinates of each node ('locs' in the TSP environment) and let the encoder infer the Euclidean distances. In ATSP I think you can't, because all you have is an asymmetric distance matrix ('cost_matrix' in the ATSP environment): node coordinates would not help the encoder understand why going from one node to another has one cost while going back has a different one.

So I think that to solve the ATSP problem with the AM model, you need a custom InitEmbedding that encodes the nodes in such a way that it also provides information about the asymmetric distance matrix. Maybe a GNN or something like that.
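
As a rough illustration (not library code; the module name and feature choice are my own assumptions), a minimal custom InitEmbedding could build each node's initial embedding from its row and column of cost_matrix, so the encoder sees both outgoing and incoming costs:

import torch
import torch.nn as nn

class ATSPCostInitEmbedding(nn.Module):
    # Hypothetical init embedding: each node i is described by row i
    # (costs from i to all nodes) and column i (costs from all nodes to i)
    # of the cost matrix, projected to the embedding dimension.
    def __init__(self, num_loc, embed_dim):
        super().__init__()
        self.proj = nn.Linear(2 * num_loc, embed_dim)

    def forward(self, td):
        cost = td["cost_matrix"]  # [batch, num_loc, num_loc]
        feats = torch.cat([cost, cost.transpose(-2, -1)], dim=-1)  # [batch, num_loc, 2 * num_loc]
        return self.proj(feats)  # [batch, num_loc, embed_dim]

Note this ties the module to a fixed num_loc; a GNN, or the MatNet encoder mentioned below, avoids that limitation.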

cbhua commented 9 months ago

Hi @Mu-Yanchen, thanks for raising this bug and sorry for our late reply. Also thanks to @Haimrich's help!

In the current version, we applied MatNet [1] to the ATSP. Unlike other environments, the initial embedding for ATSP is located here.

We updated the MatNet implementation in b3f1446820fd6c2d9ac3399369ffc134dd86b3ab. You may want to check the minimal test in this notebook and play with it 🚀.

[1] Kwon, Yeong-Dae, et al. "Matrix encoding networks for neural combinatorial optimization." Advances in Neural Information Processing Systems 34 (2021): 5138-5149.
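
To answer the training question above, here is a minimal sketch of training MatNet on ATSP, assuming MatNet follows the same constructor and trainer conventions as the other rl4co models (the notebook linked above has the tested version):

from rl4co.envs import ATSPEnv
from rl4co.models import MatNet
from rl4co.utils import RL4COTrainer

# assumed: the environment is passed as the first argument, as for other rl4co models
env = ATSPEnv(num_loc=30)
model = MatNet(env)
trainer = RL4COTrainer(max_epochs=3, devices=1)
trainer.fit(model)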

fedebotu commented 5 months ago

Closing now. Feel free to reopen if any other issue arises! :+1: