pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

FR Transpose transformation #2480

Open akern40 opened 4 years ago

akern40 commented 4 years ago

Issue Description

Many normalizing flows require transpositions during the flow so that different transformations can operate on different dimensions. Since transpositions are volume-preserving (their Jacobian has a log-determinant of zero), this operation would be relatively easy to implement.
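For reference, a minimal sketch of what such a transform could look like, assuming the permutation fully specifies every dimension it acts on (the class below is illustrative, not an existing Pyro API):

import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform

class Transpose(Transform):
    """Sketch of a volume-preserving transform that permutes dimensions."""
    bijective = True

    def __init__(self, permutation, cache_size=1):
        super().__init__(cache_size=cache_size)
        self.permutation = tuple(permutation)

    @property
    def domain(self):
        return constraints.independent(constraints.real, len(self.permutation))

    @property
    def codomain(self):
        return constraints.independent(constraints.real, len(self.permutation))

    def _call(self, x):
        return x.permute(self.permutation)

    def _inverse(self, y):
        # Invert the permutation: if output dim i came from input dim p,
        # then input dim p is recovered from output dim i.
        inverse = [0] * len(self.permutation)
        for i, p in enumerate(self.permutation):
            inverse[p] = i
        return y.permute(inverse)

    def log_abs_det_jacobian(self, x, y):
        # A permutation only reorders elements, so |det J| = 1 and
        # log|det J| = 0 for every batch element.
        return x.new_zeros(x.shape[:x.dim() - len(self.permutation)])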

One implementation question is how to handle transpositions on batched tensors, e.g.

x = torch.rand(5, 4, 3)

# If we only want to permute the last two dimensions, can we do
transpose = Transpose((1, 0))
transpose(x).shape  # torch.Size([5, 3, 4])

# Or do we have to specify the full permutation explicitly
transpose = Transpose((0, 2, 1))
transpose(x).shape  # torch.Size([5, 3, 4])

While the first method seems a little odd, its advantage is that it handles both batched and unbatched tensors (the transformation would assume the transposition applies to the rightmost dimensions); see the sketch below.
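To make the trade-off concrete, here is how the two options would resolve to plain torch.permute calls (Transpose is the proposed transform, not an existing API):

import torch

x_batched = torch.rand(5, 4, 3)
x_single = torch.rand(4, 3)

# Option 1: Transpose((1, 0)) is interpreted relative to the rightmost
# dimensions, so one spec covers both ranks:
x_batched.permute(0, 2, 1).shape  # torch.Size([5, 3, 4])
x_single.permute(1, 0).shape      # torch.Size([3, 4])

# Option 2: the caller must spell out a different full permutation per
# rank, e.g. Transpose((0, 2, 1)) for the batched input and
# Transpose((1, 0)) for the unbatched one.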

cc @stefanwebb

akern40 commented 4 years ago

After some quick testing, it turns out that negative indices work with torch.permute, so perhaps the best approach would be simply to require that negative indices be passed to the transformation:

x = torch.rand(5, 4, 3)

# Using negative dimensions
transpose = Transpose((-1, -2))
transpose(x).shape  # torch.Size([5, 3, 4])

This still assumes the transformation operates on the rightmost tensor dimensions, but it makes that assumption clearer to the end user. The expanded permutation can be obtained by prepending the identity indices over the leading batch dimensions to the given all-negative permutation, as sketched below.
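Concretely, the expansion could look like this hypothetical helper, where an all-negative permutation of length k over an n-dimensional tensor is prefixed with the identity over the n - k leading batch dimensions:

import torch

def expand_permutation(perm, ndim):
    # Hypothetical helper: expand an all-negative permutation over the
    # rightmost dimensions into a full permutation over all ndim
    # dimensions, leaving the leading (batch) dimensions in place.
    assert all(p < 0 for p in perm)
    n_batch = ndim - len(perm)
    return tuple(range(n_batch)) + tuple(ndim + p for p in perm)

x = torch.rand(5, 4, 3)
full_perm = expand_permutation((-1, -2), x.dim())  # (0, 2, 1)
x.permute(full_perm).shape  # torch.Size([5, 3, 4])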