tum-pbs / PhiFlow

A differentiable PDE solving framework for machine learning
MIT License

Question regarding tensor transpose #150

Closed: Dxyk closed this issue 7 months ago

Dxyk commented 7 months ago

Hello,

I was experimenting with tensor transposes and found that they did not work the way I expected. Maybe I misunderstood something, but I'd appreciate a bit of help :) I'm currently on PhiFlow version 2.5.3 and PhiML version 1.2.1.

from phi.torch.flow import *

test_tensor = tensor([[0, 1], [2, 3]], spatial('x, y'))
print(test_tensor)                          # 0; 1; 2; 3 (xˢ=2, yˢ=2)
print(math.transpose(test_tensor, 'x, y'))  # 0; 1; 2; 3 (xˢ=2, yˢ=2)
print(math.transpose(test_tensor, 'y, x'))  # 0; 1; 2; 3 (xˢ=2, yˢ=2)
# For the last line, I expected something like 0; 2; 1; 3 (yˢ=2, xˢ=2)

For comparison, the corresponding PyTorch behaviour is:

import torch
test_tensor = torch.tensor([[0, 1], [2, 3]])
print(test_tensor)
# tensor([[0, 1],
#         [2, 3]])
print(torch.transpose(test_tensor, 1, 0))
# tensor([[0, 2],
#         [1, 3]])

Thanks in advance for helping, and also thanks for the great work!

holl- commented 7 months ago

Hi @Dxyk, this is a bug. However, for all practical applications you don't need to transpose Φ-Flow tensors, since the dimension order is irrelevant to all functions except some plotting functions. To get the transposed PyTorch tensor, you can use test_tensor.native('y,x'), which returns the native tensor with its dimensions in the order y, x.
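To make the point about dimension order concrete, here is a minimal sketch in plain NumPy. It assumes nothing about PhiML internals: `NamedArray` and its `native(order)` and `sum(dim)` methods are hypothetical stand-ins that only mimic the named-dimension idea. Because reductions and reorderings select dimensions by name, two arrays holding the same values in different storage order give identical results, and `native('y,x')` hands back the data transposed, much like `test_tensor.native('y,x')` above.

```python
import numpy as np


class NamedArray:
    """Hypothetical named-dimension array (illustration only, not PhiML's Tensor)."""

    def __init__(self, data, names):
        self.data = np.asarray(data)
        # Parse a comma-separated string such as 'x, y' into ('x', 'y')
        self.names = tuple(n.strip() for n in names.split(',') if n.strip())
        assert self.data.ndim == len(self.names), "need one name per axis"

    def native(self, order):
        # Return the raw array with its axes permuted into the requested name order
        wanted = [n.strip() for n in order.split(',')]
        return np.transpose(self.data, [self.names.index(n) for n in wanted])

    def sum(self, dim):
        # Reduce over the axis with the given name, wherever it is stored
        axis = self.names.index(dim)
        rest = ','.join(n for n in self.names if n != dim)
        return NamedArray(self.data.sum(axis=axis), rest)


a = NamedArray([[0, 1], [2, 3]], 'x, y')  # stored in (x, y) order
b = NamedArray(a.native('y,x'), 'y, x')   # same values, stored in (y, x) order

print(a.native('y,x').tolist())  # [[0, 2], [1, 3]]
print(b.native('x,y').tolist())  # [[0, 1], [2, 3]]
print(a.sum('x').native('y').tolist(), b.sum('x').native('y').tolist())  # [2, 4] [2, 4]
```

The storage order of `a` and `b` differs, yet every name-based query agrees; only code that reads raw axis positions (such as a plotting routine) would notice the difference.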

Dxyk commented 7 months ago

Hi, thanks for the reply. I have one more follow-up question: when visualizing a Field, how do I rotate it by 90 degrees?

Specifically, I have a 3D field with dimensions spatial(x, y, z). I would like to visualize the 2D field obtained by summing along the x-axis, so I do:

vis.plot(math.sum(my_field.values, dim='x'), title="side view")

This gives me a plot with a horizontal y-axis and a vertical z-axis. How do I rotate or transpose the plot so that the z-axis is horizontal and the y-axis is vertical? Since transpose doesn't work in this case, is there another workaround?

holl- commented 7 months ago

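As a workaround, you could simply define:

def transpose(x, order='y,x'):
    # Fetch the native tensor in the requested order, then re-wrap it with the Shape reordered to match
    return wrap(x.native(order), x.shape[order])
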
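
The idea behind this workaround can be sketched without PhiFlow for the side-view case. In this hypothetical NumPy version, a named array is just a `(data, names)` pair, and `wrap`, `native`, `sum_dim` and `transpose` are illustrative helpers, not PhiML's functions. Summing a 3D `(x, y, z)` array over `x` leaves `(y, z)`; transposing to `('z', 'y')` permutes the data and the names together, which is what the workaround achieves by pairing `x.native(order)` with `x.shape[order]`.

```python
import numpy as np


def wrap(data, names):
    # Hypothetical helper: pair raw data with one name per axis
    data = np.asarray(data)
    names = tuple(names)
    assert data.ndim == len(names), "need one name per axis"
    return data, names


def native(x, order):
    # Return the raw data with its axes permuted into the requested name order
    data, names = x
    return np.transpose(data, [names.index(n) for n in order])


def sum_dim(x, dim):
    # Reduce over the axis with the given name; the remaining names keep their order
    data, names = x
    return wrap(data.sum(axis=names.index(dim)), [n for n in names if n != dim])


def transpose(x, order):
    # Same pattern as the workaround: get the data in the new order,
    # then re-wrap it with the names in that same order
    return wrap(native(x, order), order)


field = wrap(np.arange(24).reshape(2, 3, 4), ('x', 'y', 'z'))  # sizes x=2, y=3, z=4
side = sum_dim(field, 'x')             # names ('y', 'z'), data shape (3, 4)
rotated = transpose(side, ('z', 'y'))  # names ('z', 'y'), data shape (4, 3)

print(side[1], side[0].shape)        # ('y', 'z') (3, 4)
print(rotated[1], rotated[0].shape)  # ('z', 'y') (4, 3)
```

Assuming the plot assigns axes by dimension order, as the horizontal y-axis in the original side view suggests, the PhiFlow call with the workaround would presumably be `vis.plot(transpose(math.sum(my_field.values, dim='x'), 'z,y'), title="side view")`.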
Dxyk commented 7 months ago

Perfect, that's exactly what I was looking for. Thanks for the help!