Closed lzl65825 closed 7 months ago
In RLTF, the code uses torch.distributions.Categorical to sample actions. For example, line 142 of benchmark_tasks/rltf/rltf_schema_flan_t5.py reads:
action = torch.distributions.Categorical(torch.stack(log_prob).detach()).sample()
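However, when the argument to Categorical is passed positionally rather than by keyword, it is interpreted as probs, not logits. Since the stacked values are log-probabilities (as the name log_prob suggests), they should be passed as logits: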
action = torch.distributions.Categorical(logits=torch.stack(log_prob).detach()).sample()
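To illustrate the difference, here is a minimal sketch. The three-way distribution in true_probs is made up for the demo and is not taken from the RLTF code, and validate_args=False is set only so the demo does not depend on argument validation settings. It compares the distribution that Categorical builds from the same log-probabilities when they are passed positionally (as probs) and when they are passed as logits:

```python
import torch
from torch.distributions import Categorical

# Hypothetical action distribution, used only for this demo.
true_probs = torch.tensor([0.7, 0.2, 0.1], dtype=torch.float64)
log_prob = true_probs.log()  # all entries are negative

# Positional argument: interpreted as `probs` and normalized by its sum.
positional = Categorical(log_prob, validate_args=False)

# Keyword argument: interpreted as `logits` and normalized via logsumexp.
keyword = Categorical(logits=log_prob)

print("positional (probs): ", positional.probs)
print("keyword    (logits):", keyword.probs)
```

With these inputs, the logits version recovers the original distribution, while the positional version divides the negative log-probabilities by their (negative) sum, so the least likely action ends up with the largest sampling probability.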