ai4co / rl4co

A PyTorch library for all things Reinforcement Learning (RL) for Combinatorial Optimization (CO)
https://rl4.co
MIT License

[Feat] Decoding refactoring #152

Closed fedebotu closed 5 months ago

fedebotu commented 5 months ago

Description

Several refactorings and features made for decoding in RL4CO.

  1. [Refactoring] Models now return logits by default (e.g. here). We do this because logits are the raw outputs of the model, and we want to decouple the modeling part from how we sample from distributions. The function handling the conversion from logits to probabilities (hence the "log_p" naming) is logit_to_probs; a minimal sketch of this step is given after this list.
  2. [Feat] New decoding strategy: we introduce nucleus sampling (i.e. top-p sampling), which keeps only the smallest set of actions whose cumulative probability exceeds a threshold and discards the rest before sampling. This can be used by simply passing a top_p > 0 to the DecodingStrategy, i.e. to the model decoder. It is ubiquitous in LLMs, and it is about time we have it too! (A from-scratch sketch of the filtering step is shown after this list.)
  3. [Refactoring] For simplicity we now default to handling probabilities instead of log probabilities (example here). This is a minor change, but it makes the code more readable and avoids having to call logp.exp() when sampling. It is also more in line with recent work on e.g. LLMs (see the short example after this list).
  4. [Refactoring, breaking change] By default, every mask now has the same semantics (example here): a value of 1 (True) means keep, i.e. a feasible action, while 0 (False) means remove, i.e. an infeasible action. This matches both TorchRL's action mask and, importantly, PyTorch's scaled_dot_product_attention: "A boolean mask where a value of True indicates that the element should take part in attention." (ref). For this reason, masks that previously had inconsistent naming and semantics now all behave this way (illustrated after this list).
  5. [Minor] Rename LogitAttention to PointerAttention (for consistency with the Pointer mechanism in Vinyals et al., 2015)
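
For item 1, here is a minimal sketch of what the logits-to-probabilities step can look like. The function body below is written from scratch for illustration (the temperature argument is an assumption; tanh clipping and other details of the actual RL4CO implementation are omitted):

```python
import torch
import torch.nn.functional as F

def logit_to_probs(logits: torch.Tensor, mask: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Turn raw model logits into a probability distribution over actions.

    `mask` follows the new convention: True/1 = feasible action, False/0 = infeasible.
    Infeasible actions get -inf logits and therefore zero probability.
    """
    logits = logits / temperature
    logits = logits.masked_fill(~mask, float("-inf"))
    return F.softmax(logits, dim=-1)

# toy example: batch of 2 states, 4 candidate actions
logits = torch.randn(2, 4)
mask = torch.tensor([[1, 1, 0, 1], [0, 1, 1, 1]], dtype=torch.bool)
probs = logit_to_probs(logits, mask)  # infeasible actions have probability 0
```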
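For item 2, a from-scratch sketch of the nucleus (top-p) filtering step; the helper name top_p_filter and its signature are illustrative and not necessarily what DecodingStrategy exposes:

```python
import torch

def top_p_filter(probs: torch.Tensor, top_p: float = 0.9) -> torch.Tensor:
    """Keep the smallest set of actions whose cumulative probability exceeds top_p,
    zero out the rest, and renormalize (nucleus sampling, Holtzman et al., 2020)."""
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    # an action is dropped if the cumulative probability of the actions ranked
    # strictly before it already exceeds top_p, so the action that crosses the
    # threshold is still kept
    cum_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    sorted_probs = sorted_probs.masked_fill(cum_before > top_p, 0.0)
    filtered = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
    return filtered / filtered.sum(dim=-1, keepdim=True)

probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
filtered = top_p_filter(probs, top_p=0.7)            # -> [[0.625, 0.375, 0.0, 0.0]]
action = torch.multinomial(filtered, num_samples=1)  # samples only from the nucleus
```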
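For item 3, the practical difference is small but makes the decoding loop read more naturally; the snippet below is illustrative only:

```python
import torch

# before: the decoder returned log-probabilities, so sampling required an exponentiation,
# e.g. action = torch.multinomial(logp.exp(), num_samples=1)

# now: probabilities can be sampled from directly
probs = torch.tensor([[0.1, 0.6, 0.3]])
action = torch.multinomial(probs, num_samples=1)

# the log-likelihood of the chosen action is still easy to recover for the loss
log_likelihood = probs.gather(-1, action).log()
```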
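For item 4, a small sketch of the unified mask convention (shapes and variable names are made up for the example): the same boolean mask that marks feasible actions can be negated to mask logits and passed as-is to PyTorch's attention:

```python
import torch
import torch.nn.functional as F

# convention: True/1 = keep (feasible action), False/0 = remove (infeasible)
mask = torch.tensor([[True, True, False, True]])

# masking logits before the softmax: fill the *negated* mask with -inf
logits = torch.randn(1, 4)
masked_logits = logits.masked_fill(~mask, float("-inf"))

# the same convention matches scaled_dot_product_attention, where True means
# "take part in attention" (batch=1, heads=1, queries=4, keys=4, head dim=8)
q = k = v = torch.randn(1, 1, 4, 8)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask[:, None, None, :])
```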

[!WARNING] Work in progress. Do not merge yet. Some checks and training runs still have bugs that need to be fixed (most probably due to the new masking behavior).

Types of changes

Checklist


CC: @LTluttmann could you have a look and see if you spot any inefficiencies or have some ideas?
CC: @Furffico @cbhua these changes are what I was talking about yesterday (note that in this case, running the softmax normalization inside the Sampling in ACO might not be needed)

fedebotu commented 5 months ago

Closed in favor of #161