ai4co / rl4co

A PyTorch library for all things Reinforcement Learning (RL) for Combinatorial Optimization (CO)
https://rl4.co
MIT License

[Feat, Refactor] Decoding: logits handling, new decoding strategies, standardize masks #161

Closed fedebotu closed 5 months ago

fedebotu commented 5 months ago

Description

Several refactorings and new features for decoding in RL4CO. Note that this PR follows #152 but has been restructured, and is now working and tested!

  1. [Refactoring] Models now return logits by default (e.g. here). We do so because logits are the raw outputs of the model, and we would like to decouple the modeling part from how we sample from distributions. The function that converts logits to log probabilities (hence the "log_p") is process_logits.
  2. [Feat] New decoding strategies!
    • We introduce nucleus sampling (i.e. top-p sampling), which keeps only the most likely actions whose cumulative probability reaches a threshold and discards the rest of the distribution before sampling. It can be enabled by simply passing top_p > 0 to the DecodingStrategy, i.e. to the model decoder. This is ubiquitous in LLMs and it is about time to have it!
    • Similarly, we introduce top-k sampling, which samples only from the k highest logits in the distribution.
  3. [Refactoring, breaking change] By default, every mask now has the same behavior (example here): the value 1 means keep (a feasible action), while 0 means remove (an infeasible action). This matches both TorchRL's action mask and, importantly, PyTorch's scaled_dot_product_attention: "A boolean mask where a value of True indicates that the element should take part in attention." (ref). As a result, masks that used to have inconsistent naming now share the same behavior.
  4. [Minor] Rename LogitAttention to PointerAttention (for consistency with the Pointer mechanism in Vinyals et al., 2015)
  5. Move all decoding utilities (such as decoding strategies, log-likelihood calculation and so on) under rl4co/utils (here).
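
To make items 1 to 3 concrete, here is a minimal sketch of how logits, the feasibility mask, top-k and top-p could fit together. It is written in plain Python for readability and is not the RL4CO implementation: the name process_logits_sketch, its signature, and the tie-handling are assumptions made for illustration only. It assumes at least one action is feasible.

```python
import math

NEG_INF = float("-inf")


def log_softmax(logits):
    """Numerically stable log-softmax over a list; -inf entries stay -inf."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def process_logits_sketch(logits, mask, top_k=0, top_p=0.0):
    """Hypothetical helper: turn raw logits into log probabilities.

    mask follows the convention described above: True (1) means the action
    is feasible and kept, False (0) means it is infeasible and removed.
    """
    # Infeasible actions get -inf so they end up with zero probability.
    logits = [x if keep else NEG_INF for x, keep in zip(logits, mask)]

    if top_k > 0:
        # Keep the k highest logits (ties at the threshold are also kept).
        k = min(top_k, len(logits))
        threshold = sorted(logits, reverse=True)[k - 1]
        logits = [x if x >= threshold else NEG_INF for x in logits]

    if top_p > 0.0:
        # Nucleus sampling: keep the smallest set of most likely actions
        # whose cumulative probability reaches top_p.
        probs = [math.exp(lp) for lp in log_softmax(logits)]
        order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
        kept, cumulative = set(), 0.0
        for i in order:
            kept.add(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        logits = [x if i in kept else NEG_INF for i, x in enumerate(logits)]

    # Renormalize over whatever survived the filters.
    return log_softmax(logits)


if __name__ == "__main__":
    raw = [2.0, 1.0, 0.5, 3.0]
    feasible = [True, True, True, False]  # last action is infeasible
    print(process_logits_sketch(raw, feasible))
    print(process_logits_sketch(raw, feasible, top_k=2))
    print(process_logits_sketch(raw, feasible, top_p=0.5))
```

Applying the mask before the filters means an infeasible action can never be selected by top-k or top-p, even when it has the largest raw logit, as in the example above.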


CC: @LTluttmann could you have a look and let us know if you spot any inefficiencies or have any ideas? CC: @Furffico @cbhua (note: made several modifications compared to the previous PR)

fedebotu commented 5 months ago

Awesome refactoring! 🚀 I will try to create a UML diagram of the overall pipeline. It could help users understand it.

Merging now!