wouterkool / attention-learn-to-route

Attention based model for learning to solve different routing problems

Reimplementation in RL4CO #58

Open · fedebotu opened this issue 12 months ago

fedebotu commented 12 months ago

Hi there 👋🏼

First of all, thanks a lot for your library; it has inspired several works in our research group! We are actively developing RL4CO, a library for all things Reinforcement Learning for Combinatorial Optimization. We started the library by modularizing the Attention Model, which is the basis for several other autoregressive models. We also build on recent software (such as TorchRL, TensorDict, PyTorch Lightning, and Hydra) and routines such as FlashAttention, and we made everything as easy to use as possible in the hope of helping practitioners and researchers.

We welcome you to check RL4CO out ^^
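
For a flavor of the interface, training the Attention Model on TSP50 boils down to roughly the following. This is only a minimal sketch: the exact module paths and constructor arguments may differ depending on the RL4CO version, so please refer to the documentation for the up-to-date API.

```python
# Minimal sketch of training the Attention Model on TSP50 with RL4CO.
# Module paths and arguments are illustrative and may differ between versions.
from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel
from rl4co.utils import RL4COTrainer

# TensorDict-based TSP environment with 50 nodes
env = TSPEnv(num_loc=50)

# Attention Model (Kool et al.) trained with REINFORCE and a rollout baseline
model = AttentionModel(env, baseline="rollout", batch_size=512)

# Thin wrapper around the PyTorch Lightning Trainer (mixed precision, multi-GPU, logging)
trainer = RL4COTrainer(max_epochs=100)
trainer.fit(model)
```

Since the policy is split into an encoder and an autoregressive decoder, the same environment and training loop are reused by the other autoregressive models built on top of the Attention Model.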

wouterkool commented 12 months ago

Hi! Thanks for bringing it to my attention, I will definitely check it out! Are you able to reproduce the results from the paper with your implementation (training time, evaluation time, and performance)?

fedebotu commented 12 months ago

Thanks for your quick answer 🚀

We would be more than happy to address your feedback if you check out RL4CO; feel free to contact us any time 😄

fedebotu commented 9 months ago

(Late) edit: now we are way more efficient as explained here!

wouterkool commented 8 months ago

Great! I have added a link in the readme. However, I wonder if you have also had a look at https://github.com/cpwan/RLOR, which claims an 8x speedup over this repo using PPO.

fedebotu commented 8 months ago

Yes, we are aware of it! From our understanding and our own testing, the claimed speedup is really the time to reach a certain cost, as reported in their Table 4: the AM reaches a cost of 5.80 on TSP50 after 24 hours of training, while their PPO implementation, with some training and testing tricks, gets there in 3 hours. So it is not a speedup per se but rather the time to reach a target performance; in fact, since their environment is written in NumPy (even if vectorized), data collection is naturally a bottleneck.

Moreover, the comparison is between their AM trained with PPO, using a larger batch size and learning rate and evaluated with the "multi-greedy" decoding scheme at inference time (what in RL4CO we call multistart, i.e., the POMO decoding scheme that starts decoding from every node and keeps the best trajectory), and a baseline AM evaluated with plain one-shot greedy decoding. For these reasons, we think the 8x speedup claim is a bit overstated 👀
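
To make the decoding difference concrete, here is a schematic sketch of one-shot greedy versus POMO-style multistart greedy evaluation. `decode_greedy` below is a hypothetical stand-in for an autoregressive policy rollout, not an actual function of either library:

```python
import torch

def decode_greedy(policy, coords, start_nodes):
    """Hypothetical stand-in: greedily roll out an autoregressive policy
    from the given start node of each instance and return tour costs.
    coords: (batch, n, 2), start_nodes: (batch,) -> costs: (batch,)"""
    return policy(coords, start_nodes)

def one_shot_greedy(policy, coords):
    # Baseline protocol: a single greedy rollout per instance.
    batch = coords.size(0)
    starts = torch.zeros(batch, dtype=torch.long)
    return decode_greedy(policy, coords, starts)

def multistart_greedy(policy, coords):
    # POMO-style "multi-greedy": one greedy rollout per possible start node,
    # then keep the best (lowest-cost) trajectory of each instance.
    batch, n, _ = coords.shape
    coords_rep = coords.repeat_interleave(n, dim=0)    # (batch * n, n, 2)
    starts = torch.arange(n).repeat(batch)             # (batch * n,)
    costs = decode_greedy(policy, coords_rep, starts)  # (batch * n,)
    return costs.view(batch, n).min(dim=1).values      # best over the n starts
```

With multistart, the model effectively gets n greedy rollouts per instance at test time and keeps the best one, whereas the baseline AM gets a single rollout, which is part of why we think the reported gap is inflated.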