fedebotu opened 12 months ago

Hi there 👋🏼

First of all, thanks a lot for your library, it has inspired several works in our research group! We are actively developing RL4CO, a library for all things Reinforcement Learning for Combinatorial Optimization. We started the library by modularizing the Attention Model, which is the basis for several other autoregressive models. We also build on recent software (TorchRL, TensorDict, PyTorch Lightning, and Hydra) as well as routines such as FlashAttention, and made everything as easy to use as possible in the hope of helping practitioners and researchers (see the sketch below).

We welcome you to check RL4CO out ^^
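For a flavor of the intended usage, here is a minimal training sketch along the lines of the RL4CO quickstart; exact module paths and argument names (`TSPEnv`, `AttentionModel`, `RL4COTrainer`, `num_loc`, `baseline`) may differ between versions, so treat it as illustrative rather than as the library's definitive API:

```python
# Minimal sketch based on the RL4CO quickstart; module paths and argument
# names may differ between versions, so check the current docs.
from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel
from rl4co.utils.trainer import RL4COTrainer

env = TSPEnv(num_loc=50)                          # TSP instances with 50 nodes
model = AttentionModel(env, baseline="rollout")   # AM with a greedy-rollout baseline
trainer = RL4COTrainer(max_epochs=100)            # thin wrapper around Lightning's Trainer
trainer.fit(model)
```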
Hi! Thanks for bringing it to my attention, I will definitely check it out! Are you able to reproduce the results from the paper with your implementation (training time, evaluation time, and performance)?
Thanks for your quick answer 🚀
*(training curves screenshot omitted)*
- The bump in performance over the last 20 epochs is a step of the MultiStepLR scheduler.
- Setting `mask_inner` to `False` enables the FlashAttention routine during decoding (i.e., at each step), and with it the training above takes 12.5 hours, even with the current TensorDict slowness! This does degrade performance somewhat, but FlashAttention with masking may be added in the near future, so it holds good promise!

We would be more than happy to address your feedback if you check out RL4CO; you may contact us any time 😄
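For context on the scheduler bump: a late MultiStepLR milestone suddenly shrinks the learning rate, which typically shows up as a step improvement in the validation curve. A minimal sketch (the milestone epochs and values here are made up for illustration):

```python
import torch

params = [torch.zeros(1, requires_grad=True)]
optimizer = torch.optim.Adam(params, lr=1e-4)
# LR is multiplied by gamma at each milestone epoch (values illustrative).
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 95], gamma=0.1)

for epoch in range(100):
    # ... run one training epoch, calling optimizer.step() per batch ...
    scheduler.step()  # lr: 1e-4 -> 1e-5 at epoch 80 -> 1e-6 at epoch 95
```

And on why masking matters: with PyTorch's `scaled_dot_product_attention`, the FlashAttention backend is only eligible when no arbitrary attention mask is passed, which is why dropping the inner mask speeds up per-step decoding. A hedged illustration (the shapes and the tie-in to `mask_inner` are our own, not RL4CO's exact internals):

```python
import torch
import torch.nn.functional as F

# Shapes: [batch, heads, queries, head_dim]; one query = one decoding step.
q = torch.randn(64, 8, 1, 16)
k = torch.randn(64, 8, 50, 16)
v = torch.randn(64, 8, 50, 16)

# Without attn_mask, SDPA may dispatch to the FlashAttention kernel
# (on CUDA with fp16/bf16 inputs).
out = F.scaled_dot_product_attention(q, k, v)

# With a boolean mask (True = allowed to attend), e.g. to exclude already
# visited nodes, the flash backend is ruled out and SDPA falls back to the
# math / memory-efficient implementations.
allowed = torch.ones(64, 8, 1, 50, dtype=torch.bool)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)
```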
Great! I have added a link in the readme. However, I wonder if you have also had a look at https://github.com/cpwan/RLOR; they claim an 8x speedup over this repo using PPO.
Yes, we are aware of it! From our understanding and our testing, the reported speedup is actually the time to reach a certain cost, as seen in their Table 4: AM reaches 5.80 on TSP50 in 24 hours of training, while their PPO implementation, with some training and testing tricks, takes 3 hours. So it is not a speedup per se (in fact, with the environment in NumPy, even though vectorized, data collection is naturally a bottleneck) but rather the time to reach a target performance. Besides, the comparison pits their AM, trained with PPO, a larger batch size, and a higher learning rate, and evaluated with the "multi-greedy" decoding scheme at inference (what in RL4CO we call multistart, i.e., the POMO decoding scheme that starts decoding from every node and keeps the best trajectory), against the baseline AM evaluated with one-shot greedy decoding. For these reasons, we think the 8x speedup claim is a bit overstated 👀
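To make the decoding-scheme difference concrete, here is a toy multistart sketch; nearest-neighbor stands in for the learned policy (in AM/POMO the greedy rollout would come from the trained decoder instead), so this is not RL4CO's or RLOR's implementation:

```python
import torch

def tour_length(coords, tour):
    # coords: [B, N, 2]; tour: [B, N] node indices -> closed-tour length [B].
    ordered = coords.gather(1, tour.unsqueeze(-1).expand(-1, -1, 2))
    nxt = torch.roll(ordered, shifts=-1, dims=1)
    return (ordered - nxt).norm(dim=-1).sum(dim=1)

def greedy_decode(coords, start):
    # Stand-in greedy "policy": nearest-neighbor rollout from node `start`.
    B, N, _ = coords.shape
    visited = torch.zeros(B, N, dtype=torch.bool)
    tour = torch.empty(B, N, dtype=torch.long)
    current = torch.full((B,), start, dtype=torch.long)
    for t in range(N):
        tour[:, t] = current
        visited[torch.arange(B), current] = True
        dist = (coords - coords[torch.arange(B), current].unsqueeze(1)).norm(dim=-1)
        dist[visited] = float("inf")  # mask already-visited nodes
        if t < N - 1:
            current = dist.argmin(dim=-1)
    return tour

def multistart_greedy(coords):
    # POMO-style multistart: decode from every node, keep the best tour
    # per instance, i.e., a best-of-N greedy estimate.
    B, N, _ = coords.shape
    best_cost = torch.full((B,), float("inf"))
    best_tour = torch.empty(B, N, dtype=torch.long)
    for start in range(N):
        tour = greedy_decode(coords, start)
        cost = tour_length(coords, tour)
        better = cost < best_cost
        best_cost = torch.where(better, cost, best_cost)
        best_tour[better] = tour[better]
    return best_tour, best_cost

coords = torch.rand(4, 20, 2)                         # 4 random TSP-20 instances
one_shot = tour_length(coords, greedy_decode(coords, 0))
_, multi = multistart_greedy(coords)
print(one_shot, multi)                                # multistart <= one-shot
```

Comparing the multistart cost against the one-shot greedy cost is exactly the apples-to-oranges comparison above: the former is a best-of-N rollout, so it is stronger by construction.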