agiresearch / OpenAGI

OpenAGI: When LLM Meets Domain Experts
MIT License
1.82k stars, 151 forks

Incorrect categorical distribution setting #33

Closed: lzl65825 closed this issue 7 months ago

lzl65825 commented 7 months ago

In RLTF, the code uses torch.distributions.Categorical to sample actions. For example, line 142 of benchmark_tasks/rltf/rltf_schema_flan_t5.py reads:

action = torch.distributions.Categorical(torch.stack(log_prob).detach()).sample()

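However, when the argument to Categorical is passed positionally, without a keyword, it is interpreted as probs rather than logits. Because the values stacked from log_prob are log-probabilities, they should be passed as logits: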

action = torch.distributions.Categorical(logits=torch.stack(log_prob).detach()).sample()
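The sketch below illustrates why the positional call is more than a naming issue. It assumes each entry of log_prob is a scalar log-probability tensor; the three values are made up for illustration and are not taken from the repository. With probs, Categorical divides the input by its sum. When every entry is negative, that division produces a valid-looking distribution proportional to |log p|, so the least likely action ends up the most likely one. Passing logits= normalises with a softmax instead and recovers the intended probabilities.

```python
import torch
from torch.distributions import Categorical

# Hypothetical log-probabilities for three actions (illustrative values only).
# Their exponentials are 0.7, 0.2 and 0.1.
log_prob = [torch.tensor(0.7).log(), torch.tensor(0.2).log(), torch.tensor(0.1).log()]
stacked = torch.stack(log_prob).detach()

# Positional argument: treated as probs and divided by its (negative) sum,
# giving roughly [0.084, 0.377, 0.539], which favours the least likely action.
wrong = Categorical(stacked)

# Keyword argument: treated as logits and normalised with a softmax,
# giving back [0.7, 0.2, 0.1].
right = Categorical(logits=stacked)

print("probs (positional):", wrong.probs)
print("probs (logits=)   :", right.probs)
```

Because the mis-specified distribution still sums to one and has no negative entries, it passes argument validation. The mistake therefore shows up only as a skewed sampling distribution, not as an exception.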