pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Support for Multi-Categorical in torch.distributions #43250

Open youkaichao opened 4 years ago

youkaichao commented 4 years ago

🚀 Feature

Support for Multi-Categorical in torch.distributions

Motivation

Since OpenAI Gym supports the MultiDiscrete space, it would be nice if PyTorch supported the corresponding distribution too. This is a common scenario: the action consists of multiple sub-actions, and each sub-action has a different number of choices.

Pitch

# I have three sub-actions with 3, 5, and 4 choices, respectively
dist = multi_categorical_maker([3, 5, 4])
# give it a batch of logits with 5 items.
# the first 3 logits correspond to the first sub-action,
# the next 5 to the second,
# and the last 4 to the last sub-action
d = dist(logits=torch.rand(5, 12)) 

# I get 5 actions: the first sub-action lies in [0, 3), the second in [0, 5), the last in [0, 4)
d.sample()
tensor([[0, 4, 3],
        [0, 4, 0],
        [0, 3, 1],
        [0, 0, 3],
        [0, 0, 3]])

# returns the log_prob of each action, which is the sum of the log_probs of its sub-actions
d.log_prob(d.sample())
tensor([-3.8733, -4.3669, -4.6643, -3.5386, -4.1125])

Alternatives

Possible implementation

The implementation is straightforward. Here is one example:

import torch
from torch.distributions import Categorical, Distribution
from typing import List

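class MultiCategorical(Distribution):
    # No tensor parameters of its own; an empty dict avoids the base-class
    # warning about missing arg_constraints.
    arg_constraints = {}

    def __init__(self, dists: List[Categorical]):
        # Batch shape comes from the sub-distributions; each event holds
        # one index per sub-action.
        super().__init__(batch_shape=dists[0].batch_shape,
                         event_shape=torch.Size([len(dists)]))
        self.dists = dists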

    def log_prob(self, value):
        ans = []
        for d, v in zip(self.dists, torch.split(value, 1, dim=-1)):
            ans.append(d.log_prob(v.squeeze(-1)))
        return torch.stack(ans, dim=-1).sum(dim=-1)

    def entropy(self):
        return torch.stack([d.entropy() for d in self.dists], dim=-1).sum(dim=-1)

    def sample(self, sample_shape=torch.Size()):
        return torch.stack([d.sample(sample_shape) for d in self.dists], dim=-1)

def multi_categorical_maker(nvec):
    def get_multi_categorical(logits):
        start = 0
        ans = []
        for n in nvec:
            # slice the last dim so any number of leading batch dims works
            ans.append(Categorical(logits=logits[..., start:start + n]))
            start += n
        return MultiCategorical(ans)
    return get_multi_categorical
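
The joint log-probability that the class above returns can also be cross-checked without any distribution objects. The following is only a sketch under the same layout as the pitch (logits concatenated per sub-action, one action column per sub-action); `joint_log_prob` is a hypothetical helper name, not part of torch.distributions.

```python
import torch

def joint_log_prob(logits, actions, nvec):
    """Sum of per-sub-action log-probabilities of a factorized categorical.

    logits:  (batch, sum(nvec)) unnormalized scores, concatenated per sub-action
    actions: (batch, len(nvec)) integer choices, one column per sub-action
    """
    total = logits.new_zeros(logits.shape[0])
    for i, chunk in enumerate(torch.split(logits, list(nvec), dim=-1)):
        # normalize the logits of this sub-action only
        log_p = torch.log_softmax(chunk, dim=-1)
        # pick the log-probability of the chosen index and accumulate it
        total = total + log_p.gather(-1, actions[:, i:i + 1]).squeeze(-1)
    return total

torch.manual_seed(0)
nvec = [3, 5, 4]
logits = torch.rand(5, sum(nvec))
# one random valid index per sub-action, for illustration only
actions = torch.stack([torch.randint(n, (5,)) for n in nvec], dim=-1)
lp = joint_log_prob(logits, actions, nvec)  # shape (5,)
```

Because the sub-actions are treated as independent, summing the per-sub-action log-probabilities is the same as taking the log of the product of their probabilities.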

cc @vincentqb @fritzo @neerajprad @alicanb @vishwakftw

MadcowD commented 3 years ago

+1

aowen87 commented 2 years ago

+1

desaixie commented 1 year ago

Please add this feature

mu-arkhipov commented 1 year ago

We are missing this feature!

desaixie commented 1 year ago

I noticed that torch.distributions supports batched probs/logits. Since my MultiDiscrete has the same length for each action dimension (it is actually the factorized, discretized policy from Sampled MuZero), it worked for me to use the current version of Categorical with

dist = torch.distributions.Categorical(probs=policy_probs)
action_bins = dist.sample()

where policy_probs has size (batch_size, action_dims, num_bins) and action_bins has size (batch_size, action_dims).
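
To make the shapes concrete, here is a minimal sketch of the batched approach. The sizes and the random `policy_probs` are made up for illustration. Wrapping the distribution in `torch.distributions.Independent` is an addition not mentioned in the comment above; it sums the per-dimension log-probabilities so that each batch item gets a single joint log-probability, as the original request asks for.

```python
import torch
from torch.distributions import Categorical, Independent

torch.manual_seed(0)
batch_size, action_dims, num_bins = 4, 3, 7  # illustrative sizes
policy_probs = torch.softmax(
    torch.randn(batch_size, action_dims, num_bins), dim=-1
)

base = Categorical(probs=policy_probs)   # batch_shape (4, 3)
action_bins = base.sample()              # (batch_size, action_dims)
per_dim_lp = base.log_prob(action_bins)  # (batch_size, action_dims)

# Reinterpret the last batch dim as an event dim: one joint action per item.
joint = Independent(base, 1)             # batch_shape (4,), event_shape (3,)
joint_lp = joint.log_prob(action_bins)   # (batch_size,)
```

Without `Independent`, the same value can be obtained with `per_dim_lp.sum(-1)`.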

For those who want a Multi-Categorical because each dimension has a different length, I guess you can try padding them to the same length, probably with torch.nn.utils.rnn.pad_sequence.
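
One way the padding idea could look, as a sketch rather than a tested recipe: it pads logits (not probs) with `-inf` so the padded bins get probability zero and are never sampled. It uses `torch.nn.functional.pad` instead of `pad_sequence`, since the padding here is along the last dimension. `padded_multi_categorical` is a hypothetical helper name, and the `[3, 5, 4]` layout is taken from the original pitch.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical, Independent

def padded_multi_categorical(logits, nvec):
    """Build a joint distribution from logits concatenated per sub-action."""
    max_n = max(nvec)
    chunks = torch.split(logits, list(nvec), dim=-1)
    # Right-pad every chunk to max_n with -inf, then stack the chunks
    # into shape (..., len(nvec), max_n).
    padded = torch.stack(
        [F.pad(c, (0, max_n - c.shape[-1]), value=float("-inf")) for c in chunks],
        dim=-2,
    )
    # Treat the sub-action dim as an event dim so log_prob sums over it.
    return Independent(Categorical(logits=padded), 1)

torch.manual_seed(0)
nvec = [3, 5, 4]
d = padded_multi_categorical(torch.rand(5, sum(nvec)), nvec)
actions = d.sample()            # (5, 3), column i lies in [0, nvec[i])
log_prob = d.log_prob(actions)  # (5,)
```

The cost of this approach is wasted computation on the padded bins, which only matters when the sub-action sizes differ a lot.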

mu-arkhipov commented 1 year ago

From my perspective, @desaixie's answer provides an elegant solution to the problem raised. As far as I'm concerned, the issue can be closed.