Several refactorings and features made for decoding in RL4CO.
Note that this PR builds on #152 but has been restructured - now working and tested!
[Refactoring] Models now return logits by default (e.g. here). We do so since logits are the raw outputs of the model, and we want to decouple the modeling part from how we sample from distributions. The function handling the transformation from logits to log-probabilities (hence the "log_p") is process_logits
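As a rough sketch of what this separation looks like (the actual process_logits signature in RL4CO may differ; temperature handling here is assumed):

```python
import torch
import torch.nn.functional as F

def process_logits(logits: torch.Tensor, mask: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    """Turn raw model logits into log-probabilities (illustrative sketch only).

    mask follows the new convention: True/1 = feasible action, False/0 = infeasible.
    """
    logits = logits / temperature
    # Infeasible actions get -inf so they receive zero probability mass
    logits = logits.masked_fill(~mask, float("-inf"))
    return F.log_softmax(logits, dim=-1)

logits = torch.tensor([[2.0, 1.0, 0.5]])
mask = torch.tensor([[True, True, False]])
log_p = process_logits(logits, mask)
```

This keeps the model itself agnostic to the sampling scheme: the decoder emits logits, and everything distribution-related happens in one place.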
[Feat] New decoding strategies!
We introduce nucleus sampling (i.e. top-p sampling), which restricts sampling to the smallest set of highest-probability actions whose cumulative probability exceeds top_p, discarding the low-probability tail of the distribution. This can be used by simply passing a top_p > 0 to the DecodingStrategy, i.e. to the model decoder. This is ubiquitous in LLMs and it is about time to have it!
Similarly, we introduce top-k sampling, which samples only from the k highest logits in the distribution.
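Both strategies can be sketched as logit filters applied before sampling (the function name and exact API below are illustrative, not the RL4CO interface):

```python
import torch

def filter_logits(logits: torch.Tensor, top_k: int = 0, top_p: float = 0.0) -> torch.Tensor:
    """Apply top-k and/or top-p (nucleus) filtering to logits (illustrative sketch)."""
    if top_k > 0:
        # Keep only the k highest logits; everything else becomes -inf
        kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))
    if top_p > 0.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        # Drop tokens once cumulative probability exceeds top_p,
        # shifted right so the token that crosses the threshold is kept
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        remove = remove.scatter(-1, sorted_idx, remove)
        logits = logits.masked_fill(remove, float("-inf"))
    return logits

logits = torch.tensor([[4.0, 3.0, 2.0, 1.0]])
nucleus = filter_logits(logits, top_p=0.9)   # discards the low-probability tail
topk = filter_logits(logits, top_k=2)        # keeps only the two highest logits
action = torch.multinomial(torch.softmax(nucleus, dim=-1), 1)
```

Since filtered entries are set to -inf, the subsequent softmax assigns them exactly zero probability, so sampling is guaranteed to stay inside the nucleus / top-k set.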
[Refactoring, breaking change] By default, any mask now has the same behavior (example here): the value 1 means keep (i.e. feasible action) while 0 means remove (i.e. infeasible action). This matches both TorchRL's action mask and, importantly, PyTorch's scaled_dot_product_attention: "A boolean mask where a value of True indicates that the element should take part in attention." (ref). For this reason, masks that used to have inconsistent naming now share the same behavior
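A minimal illustration of the unified convention, showing the same True = keep semantics for action masking and for PyTorch attention:

```python
import torch
import torch.nn.functional as F

# Unified convention: True/1 = keep (feasible), False/0 = mask out (infeasible)
mask = torch.tensor([True, True, False])

# Action masking: infeasible actions get -inf so they can never be sampled
logits = torch.tensor([1.0, 2.0, 3.0])
masked_logits = logits.masked_fill(~mask, float("-inf"))

# The same mask semantics work for PyTorch's scaled_dot_product_attention,
# where True means "this element takes part in attention"
q = k = v = torch.randn(1, 1, 3, 8)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Note the flip (~mask) when writing -inf: since the mask marks what to keep, the fill targets its complement.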
[Refactoring] Transferred all decoding utilities (such as decoding strategies, log-likelihood calculation, and so on) under rl4co/utils (here)
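For reference, the log-likelihood calculation mentioned above amounts to gathering the log-probability of each selected action and summing over decoding steps (function name and signature below are assumed for illustration, not the actual RL4CO API):

```python
import torch

def get_log_likelihood(log_p, actions, mask=None):
    """Sum per-step log-probabilities of the chosen actions (illustrative sketch).

    log_p:   [batch, steps, num_actions] log-probabilities from the decoder
    actions: [batch, steps] indices of the selected actions
    mask:    optional [batch, steps] boolean, True = step counts toward the likelihood
    """
    # Pick out the log-probability of each chosen action
    log_p_a = log_p.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    if mask is not None:
        log_p_a = log_p_a.masked_fill(~mask, 0.0)  # masked-out steps contribute nothing
    return log_p_a.sum(dim=-1)

log_p = torch.log(torch.tensor([[[0.5, 0.5], [0.25, 0.75]]]))
actions = torch.tensor([[0, 1]])
ll = get_log_likelihood(log_p, actions)  # log(0.5) + log(0.75)
```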
Types of changes
[x] New feature (non-breaking change which adds core functionality)
[x] Breaking change (fix or feature that would cause existing functionality to change)
Checklist
[x] My change requires a change to the documentation.
[x] I have updated the tests accordingly (required for a bug fix or a new feature).
CC: @LTluttmann could you have a look if you spot some inefficiencies or if you have some ideas?
CC: @Furffico @cbhua (note: made several modifications compared to previous PR)
[Refactoring] Renamed LogitAttention to PointerAttention (for consistency with the Pointer mechanism in Vinyals et al., 2015)