pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Swapping 2 columns in a 2d tensor #1142

Open: Kreativshikkk opened this issue 3 months ago


I have a function tridiagonalization that tridiagonalizes a matrix (a 2D tensor), and I want to map it over a batch. It contains a for loop, and each iteration swaps 2 columns and 2 rows. I do not understand how to swap 2 columns without errors. My code for rows works and looks as follows:

row_temp = matrix_stacked[pivot[None]][0]
matrix_stacked[[pivot[None]][0]] = matrix_stacked[i+1].clone()
matrix_stacked[i+1] = row_temp
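To see why the row version works, it helps to trace the shapes in a plain, unbatched call. This is a minimal sketch that assumes a 4x4 float matrix, a 0-dim integer pivot and i = 0 (all example values, not taken from the question). Note that [pivot[None]][0] builds a one-element Python list and immediately takes its only element, so it is equivalent to pivot[None], a tensor index of shape (1,).

```python
import torch

# Assumed example inputs (not from the original issue)
matrix_stacked = torch.arange(16.0).reshape(4, 4)
pivot = torch.tensor(3)  # 0-dim integer tensor
i = 0

idx = pivot[None]                                    # shape (1,), same as [pivot[None]][0]
row_temp = matrix_stacked[idx][0]                    # (1, 4) -> (4,); advanced indexing returns a copy
matrix_stacked[idx] = matrix_stacked[i + 1].clone()  # (4,) broadcasts into the (1, 4) target
matrix_stacked[i + 1] = row_temp                     # (4,) into (4,): shapes match

print(matrix_stacked)
```

Every assignment either has matching shapes or broadcasts a (4,) row into a (1, 4) slot, which is why no size error appears for rows.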

Here pivot is a tensor and i is a Python integer. For columns I have something like this:

column_temp = matrix_stacked[:, [pivot[None]][0]]
matrix_stacked[:, [pivot[None]][0]] = matrix_stacked[:, [i+1]].clone()
matrix_stacked[:, i+1] = column_temp

This does not work because of a size mismatch. How can I swap columns i+1 and pivot?
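A likely source of the size error is the last line. matrix_stacked[:, pivot[None]] indexes columns with a one-element tensor, so column_temp has shape (n, 1), while matrix_stacked[:, i+1] with a plain integer selects a slot of shape (n,), and an (n, 1) tensor cannot be written into it. The sketch below shows two options under stated assumptions: (1) keep the column dimension on both sides by indexing with [i + 1]; (2) an out-of-place alternative that builds a permutation with torch.where and gathers rows and columns through advanced indexing, avoiding in-place writes altogether. It assumes PyTorch 2.x, where torch.vmap is available (older functorch installs expose the same transform as functorch.vmap), and a 0-dim integer pivot; the helper names swap_columns_inplace and swap_rows_and_cols are hypothetical.

```python
import torch


def swap_columns_inplace(matrix_stacked, i, pivot):
    """Hypothetical helper: swap columns i+1 and pivot in place (unbatched)."""
    idx = pivot[None]                                              # shape (1,)
    column_temp = matrix_stacked[:, idx]                           # (n, 1), a copy
    matrix_stacked[:, idx] = matrix_stacked[:, [i + 1]].clone()    # (n, 1) into (n, 1)
    matrix_stacked[:, [i + 1]] = column_temp                       # (n, 1) into (n, 1): no mismatch
    return matrix_stacked


def swap_rows_and_cols(matrix, i, pivot):
    """Hypothetical helper: return a copy with rows and columns i+1 and pivot swapped."""
    n = matrix.shape[-1]
    idx = torch.arange(n, device=matrix.device)
    pivot = pivot.to(idx.dtype)
    # Permutation mapping i+1 -> pivot and pivot -> i+1, identity elsewhere
    perm = torch.where(
        idx == i + 1,
        pivot,
        torch.where(idx == pivot, torch.full_like(idx, i + 1), idx),
    )
    # Reorder rows, then columns, with the same permutation
    return matrix[perm][:, perm]


A = torch.arange(9.0).reshape(3, 3)
cols_swapped = swap_columns_inplace(A.clone(), 0, torch.tensor(2))
both_swapped = swap_rows_and_cols(A, 0, torch.tensor(2))

# Batched use: one pivot per matrix; i stays a Python int captured by the lambda
batch = torch.arange(18.0).reshape(2, 3, 3)
pivots = torch.tensor([2, 1])
batched = torch.vmap(lambda m, p: swap_rows_and_cols(m, 0, p))(batch, pivots)
```

When pivot equals i + 1 the permutation is the identity, so that batch element is returned unchanged. The out-of-place variant trades an extra copy per iteration for having no data-dependent in-place writes inside the mapped function.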