I have a function `tridiagonalization` that tridiagonalizes a matrix (a 2-D tensor), and I want to map it over a batch. It involves a for loop, and each iteration permutes two columns and two rows of the matrix. I do not understand how to permute two columns without errors. My code for rows works and looks as follows:
Where `pivot` is a tensor and `i` is a Python integer variable. For columns I have something like this:

It does not work because of issues with size. What should I do in order to permute the `i+1` and `pivot` columns?
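
For reference, here is a minimal sketch of the operation being asked about, written in NumPy because the tensor library is not named above. It builds an index permutation that exchanges two positions and gathers the columns in that order, so no in-place assignment with mismatched sizes is needed. The names `swap_columns` and `swap_rows` are hypothetical helpers, not part of `tridiagonalization`, and the same idea should carry over to any library that offers a gather-style indexing operation along an axis.

```python
import numpy as np

def swap_columns(matrix, a, b):
    """Return a copy of `matrix` with columns `a` and `b` exchanged."""
    n = matrix.shape[1]
    perm = np.arange(n)        # identity permutation 0..n-1
    perm[[a, b]] = [b, a]      # exchange the two positions
    return matrix[:, perm]     # gather columns in the permuted order

def swap_rows(matrix, a, b):
    """Swap rows by swapping columns of the transpose (hypothetical helper)."""
    return swap_columns(matrix.T, a, b).T

# Illustrative values only: `i` is a Python int, `pivot` an integer index.
m = np.arange(16).reshape(4, 4)
i, pivot = 0, 3
swapped = swap_columns(m, i + 1, pivot)
```

Because the result is a new array of the same shape, the swap can be used inside a loop by reassigning the matrix on each iteration.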