Closed: msluszniak closed this issue 1 year ago.
Is this the case on all backends or just binary-backend?
I've checked the EXLA and Torchx backends, and they seem to have the same incorrect result.
@msluszniak Torchx returned the correct result for me.
I used u |> Nx.dot(Nx.make_diagonal(s)) |> Nx.dot(vt) to reconstruct the matrix.
@polvalente yes I agree, I messed up something with the backend initially, but now I got the correct result for Torchx.
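As a point of comparison, the same reconstruction check can be sketched in NumPy (using the input tensor that appears later in the thread); np.diag plays the role of Nx.make_diagonal here:

```python
# Reconstruct A from its SVD factors and compare against the input within a
# floating-point tolerance, mirroring
#   u |> Nx.dot(Nx.make_diagonal(s)) |> Nx.dot(vt)
import numpy as np

a = np.array([
    [0.8583984375, 0.0224609375, 0.708984375, 0.998046875],
    [0.6455078125, 0.9814453125, 0.390625, 0.984375],
    [0.6728515625, 0.21875, 0.3369140625, 0.7587890625],
    [0.8095703125, 0.8837890625, 0.2412109375, 0.7744140625],
])

u, s, vt = np.linalg.svd(a)
reconstructed = u @ np.diag(s) @ vt
assert np.allclose(reconstructed, a, atol=1e-6)
```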
Ok, so there seem to be two problems happening.
1: In the BinaryBackend, the implementation falls into some sort of fixed point because two of the singular values start close to each other (BinaryBackend converges them to 0.67 and 0.66; the actual values, returned by EXLA as well as JAX, are 0.83 and 0.69). This indicates a poorly conditioned matrix that the BinaryBackend isn't fully prepared to handle.
2: EXLA is returning v instead of vt. If you do u |> Nx.multiply(s) |> Nx.dot(vt |> Nx.LinAlg.adjoint()), you get the original tensor back within a tolerance. This should be an easy fix, but I still need to confirm this is the case for all matrices.
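The workaround above amounts to taking the adjoint of the returned matrix. A minimal NumPy sketch of the same idea, simulating a backend that returns V where Vt is expected:

```python
# If a backend hands back V instead of Vt, reconstructing with the adjoint
# (conjugate transpose) of the returned matrix still recovers A, mirroring
#   u |> Nx.multiply(s) |> Nx.dot(vt |> Nx.LinAlg.adjoint())
import numpy as np

a = np.random.default_rng(0).random((4, 4))
u, s, vt = np.linalg.svd(a)

v_mistakenly_returned = vt.T  # simulate the suspected bug: V instead of Vt

# (u * s) scales the columns of U by the singular values, i.e. U @ diag(s).
assert np.allclose((u * s) @ v_mistakenly_returned.conj().T, a, atol=1e-6)
```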
For this tensor
[
[0.8583984375, 0.0224609375, 0.708984375, 0.998046875],
[0.6455078125, 0.9814453125, 0.390625, 0.984375],
[0.6728515625, 0.21875, 0.3369140625, 0.7587890625],
[0.8095703125, 0.8837890625, 0.2412109375, 0.7744140625]
]
We get
{#Nx.Tensor<
f32[4][4]
[
[-0.5039474368095398, 0.7379504442214966, 0.15171891450881958, 0.42243048548698425],
[-0.5674749612808228, -0.45060065388679504, 0.676255464553833, -0.13270191848278046],
[-0.3948972523212433, 0.2652057111263275, -0.3157660663127899, -0.8209834694862366],
[-0.5177502632141113, -0.42667800188064575, -0.6480368375778198, 0.3604564070701599]
]
>,
#Nx.Tensor<
f32[4]
[2.6915104389190674, 0.8506444096565247, 0.21740709245204926, 0.08219550549983978]
>,
#Nx.Tensor<
f32[4][4]
[
[-0.5512740015983582, -0.3966410458087921, -0.7097650170326233, -0.18710029125213623],
[0.2064405381679535, -0.37330538034439087, 0.2753744423389435, -0.861506998538971],
[-0.7834685444831848, 0.3717951476573944, 0.4544169008731842, -0.2035943865776062],
[0.19913889467716217, 0.7517229318618774, -0.46250173449516296, -0.4258503019809723]
]
>}
The U matrix and singular values are correct, but vt has only the first column correct.
Correct results for reference:
[[-0.50394743 0.73795044 0.15171892 0.42243048]
[-0.56747496 -0.45060066 0.67625544 -0.13270191]
[-0.39489726 0.26520572 -0.31576606 -0.82098348]
[-0.51775024 -0.42667801 -0.64803684 0.36045639]]
[2.69151034 0.85064443 0.2174071 0.08219551]
[[-0.551274 -0.41323659 -0.31093851 -0.65471348]
[ 0.20644053 -0.87550618 0.39218688 0.19251152]
[-0.78346856 0.11643493 0.50149706 0.34802387]
[ 0.19913889 0.22173713 0.70569639 -0.6427822 ]]
So I think there is a more general problem with BinaryBackend.
Yes, there's definitely something to fix in the BinaryBackend. It's strange that the problem is only with vt, and that the returned vt is still orthogonal.
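To make the orthogonality observation concrete: a valid Vt satisfies Vt @ Vt.T == I, but passing that check alone does not mean the decomposition is correct (returning V instead of Vt also yields an orthogonal matrix). Checking the correct reference vt quoted above:

```python
# Orthogonality check for the reference vt: the rows should be orthonormal,
# so vt @ vt.T should be the identity up to rounding error.
import numpy as np

vt = np.array([
    [-0.551274, -0.41323659, -0.31093851, -0.65471348],
    [0.20644053, -0.87550618, 0.39218688, 0.19251152],
    [-0.78346856, 0.11643493, 0.50149706, 0.34802387],
    [0.19913889, 0.22173713, 0.70569639, -0.6427822],
])

assert np.allclose(vt @ vt.T, np.eye(4), atol=1e-4)
```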
@msluszniak After the merge of #1015 and #1022, we now have the correct results (up to a rounding error):
iex(2)> {u, s, vh} = Nx.LinAlg.svd(t); u |> Nx.dot(Nx.eye(Nx.shape(t)) |> Nx.put_diagonal(s)) |> Nx.dot(vh)
#Nx.Tensor<
f32[4][4]
[
[0.8645899295806885, 0.021466584876179695, 0.7096304297447205, 1.0008362531661987],
[0.6423190236091614, 0.9786409139633179, 0.3899857699871063, 0.9842138290405273],
[0.6740410327911377, 0.21641786396503448, 0.3399522006511688, 0.7606345415115356],
[0.8090260624885559, 0.8779084086418152, 0.24401454627513885, 0.7728657722473145]
]
>
iex(3)> t
#Nx.Tensor<
f32[4][4]
[
[0.8583984375, 0.0224609375, 0.708984375, 0.998046875],
[0.6455078125, 0.9814453125, 0.390625, 0.984375],
[0.6728515625, 0.21875, 0.3369140625, 0.7587890625],
[0.8095703125, 0.8837890625, 0.2412109375, 0.7744140625]
]
>
{#Nx.Tensor<
f32[4][4]
[
[0.5066919922828674, -0.7352122068405151, -0.14841154217720032, 0.4251161515712738],
[0.5657464265823364, 0.4564741551876068, -0.6759865283966064, -0.1207776740193367],
[0.39590415358543396, -0.265470415353775, 0.299826979637146, -0.8263956904411316],
[0.516191840171814, 0.4249938130378723, 0.6566023826599121, 0.34893032908439636]
]
>,
#Nx.Tensor<
f32[4]
[2.6914844512939453, 0.8506505489349365, 0.2173856943845749, 0.08223722130060196]
>,
#Nx.Tensor<
f32[4][4]
[
[0.5520889163017273, 0.4099556505680084, 0.31237152218818665, 0.6554068922996521],
[-0.20873655378818512, 0.877675473690033, -0.38823550939559937, -0.18811701238155365],
[0.7856599688529968, -0.10767427086830139, -0.491233766078949, -0.3603341579437256],
[0.18530665338039398, 0.22368313372135162, 0.7144758701324463, -0.636509895324707]
]
>}
Note that signs are swapped for u and vh, but that's not an error by itself because the change cancels out when reconstructing the matrix.
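This sign ambiguity is easy to demonstrate: flipping the sign of column i of U together with row i of Vh leaves the product unchanged, because the two diagonal sign matrices cancel. A small NumPy sketch:

```python
# Flip the signs of matching U columns and Vh rows and verify the
# reconstruction U @ diag(S) @ Vh is unaffected.
import numpy as np

a = np.random.default_rng(1).random((4, 4))
u, s, vh = np.linalg.svd(a)

flip = np.array([1.0, -1.0, -1.0, 1.0])  # arbitrary per-factor signs
u_flipped = u * flip                     # flips columns of U
vh_flipped = vh * flip[:, None]          # flips the matching rows of Vh

assert np.allclose(u_flipped @ np.diag(s) @ vh_flipped, a, atol=1e-6)
```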
For the following input
svd returns
But if we now try to reconstruct the original matrix
we get
a result which does not correspond to the original one. On the other hand, for NumPy's svd we get
a correct decomposition (u @ np.diag(s) @ vt gives the original matrix).
For u the first column is correct, for s the first value, and for vt only vt[[0, 0]].