## Description

Several refactorings and new features for decoding in RL4CO.
[Refactoring] Models now return `logits` by default (e.g. here). We do this because logits are the raw outputs of the model, and we want to decouple the modeling part from how we sample from distributions. The function handling the conversion from logits to probabilities (replacing the previous `log_p` handling) is `logit_to_probs`.
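For intuition, here is a minimal sketch of what such a conversion can look like; the signature, the `mask` argument, and the `temperature` argument are illustrative assumptions, not the actual RL4CO API:

```python
import torch
import torch.nn.functional as F

def logit_to_probs(logits: torch.Tensor, mask: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Turn raw model logits into a probability distribution over actions.
    Infeasible actions (mask == False/0) are assigned probability 0."""
    logits = logits / temperature                      # optional temperature scaling
    logits = logits.masked_fill(~mask, float("-inf"))  # mask: True/1 = feasible (keep)
    return F.softmax(logits, dim=-1)
```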
[Feat] New decoding strategy: we introduce nucleus sampling (i.e. top-p sampling), which keeps only the smallest set of actions whose cumulative probability exceeds the threshold and discards the rest of the distribution before sampling (a sketch of the filtering step is shown below). This can be used by simply passing a `top_p` > 0 to the `DecodingStrategy`, i.e. to the model decoder. This is ubiquitous in LLMs and it is about time to have it!
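Below is a self-contained sketch of the standard top-p filtering logic; the function name `apply_top_p` is illustrative, not the actual RL4CO API:

```python
import torch

def apply_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Keep the smallest set of actions whose cumulative probability exceeds
    top_p; set the logits of everything else to -inf before sampling."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    sorted_remove = cumulative_probs > top_p
    # Shift right so the first action crossing the threshold is kept
    # (at least one action always survives)
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False
    remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
    return logits.masked_fill(remove, float("-inf"))

# Sampling then proceeds on the filtered logits:
probs = torch.softmax(apply_top_p(torch.randn(2, 10), top_p=0.9), dim=-1)
action = torch.multinomial(probs, num_samples=1)
```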
[Refactoring] For simplicity, we now default to handling probabilities instead of log probabilities (example here). This is a minor change, but it makes the code more readable and avoids having to do `logp.exp()` when sampling. This is also more in line with recent works on e.g. LLMs.
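As a hypothetical illustration (variable names are made up), sampling now works directly on probabilities:

```python
import torch

logits = torch.randn(4, 10)
probs = torch.softmax(logits, dim=-1)

# Before: log-probabilities had to be exponentiated before sampling
# action = logp.exp().multinomial(1)

# Now: sample directly from probabilities
action = probs.multinomial(1).squeeze(-1)

# Log-probabilities (e.g. for a REINFORCE loss) are still easy to recover:
log_likelihood = probs.gather(-1, action.unsqueeze(-1)).log().squeeze(-1)
```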
[Refactoring, breaking change] By default, any mask now has the same behavior (example here): the value 1 means keep (i.e. feasible action) while 0 means remove (i.e. infeasible action). This matches TorchRL's action mask and, importantly, PyTorch's `scaled_dot_product_attention`: "A boolean mask where a value of True indicates that the element should take part in attention." (ref). For this reason, masks that used to have inconsistent naming now all share this behavior (see the sketch below).

[Refactoring] Renamed `LogitAttention` to `PointerAttention` (for consistency with the Pointer mechanism in Vinyals et al., 2015).
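A minimal sketch of the new mask convention (shapes and names are arbitrary):

```python
import torch

logits = torch.randn(4, 10)
action_mask = torch.ones(4, 10, dtype=torch.bool)  # True/1 = feasible (keep)
action_mask[:, 0] = False                          # False/0 = infeasible (remove)

# Same convention as PyTorch's scaled_dot_product_attention:
# True means "take part", so we mask out the complement.
masked_logits = logits.masked_fill(~action_mask, float("-inf"))
probs = torch.softmax(masked_logits, dim=-1)
assert torch.all(probs[:, 0] == 0)  # infeasible actions get probability 0
```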
> [!WARNING]
> Work in progress. Do not merge yet. Some checks and training runs still have bugs that need to be fixed (most probably due to the new masking).
## Types of changes

- [x] New feature (non-breaking change which adds core functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to change)

## Checklist

- [x] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (required for a bug fix or a new feature).
- [ ] I have updated the documentation accordingly.
CC: @LTluttmann could you have a look and see if you spot any inefficiencies, or if you have some ideas?

CC: @Furffico @cbhua these changes are what I was talking about yesterday (note that in this case running the softmax normalization inside the `Sampling` in ACO might not be needed)