elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Incorrect results for Nx.LinAlg.svd #949

Closed msluszniak closed 1 year ago

msluszniak commented 1 year ago

For the following input

Nx.tensor([
            [
              0.33747746329154127,
              0.8448716981481403,
              0.8465694977489985,
              0.1926030183919778,
              0.04625514004748599,
              0.7502944398364579
            ],
            [
              0.4446625676763911,
              0.8463476875150429,
              0.39503640704174303,
              0.7910270477615085,
              0.8722376324896636,
              0.6758646358483182
            ],
            [
              0.6154292118141929,
              0.5455230739505744,
              0.9565376231248434,
              0.2790218491103198,
              0.5663205639536116,
              0.29588894254993525
            ],
            [
              0.6873114496145727,
              0.2603452300422152,
              0.5479350062232057,
              0.5267668983186267,
              0.2557562799821602,
              0.4790844622306156
            ],
            [
              0.3298696032797205,
              0.3446971837357009,
              0.2888187784379451,
              0.6165562827943281,
              0.27242014359429534,
              0.0243891670454095
            ],
            [
              0.8073663574129741,
              0.6744673959108053,
              0.24853954732383965,
              0.26991916232511237,
              0.3544102499522487,
              0.8091680144952467
            ]
          ])

svd returns

{#Nx.Tensor<
   f32[6][6]
   [
     [-0.42250970005989075, 0.4181019365787506, 0.0026111374609172344, -0.10452315956354141, -0.7814749479293823, 0.15824458003044128],
     [-0.517398476600647, 0.12293870002031326, 0.5080962777137756, 0.5810934901237488, 0.30394962430000305, 0.1702001541852951],
     [-0.43182945251464844, -0.5853010416030884, 0.39497026801109314, -0.5105135440826416, -0.055890556424856186, -0.2262668013572693],
     [-0.35987240076065063, -0.36977916955947876, -0.5620261430740356, -0.0012213886948302388, 0.12350377440452576, 0.6345291137695312],
     [-0.23809058964252472, -0.28327229619026184, -0.4136335253715515, 0.5061729550361633, -0.21808022260665894, -0.6230625510215759],
     [-0.42557311058044434, 0.5005193948745728, -0.31442612409591675, -0.36683598160743713, 0.4805992543697357, -0.3224279582500458]
   ]
 >,
 #Nx.Tensor<
   f32[6]
   [3.1087610721588135, 0.67035311460495, 0.6646991968154907, 0.5060449242591858, 0.43972793221473694, 0.17170560359954834]
 >,
 #Nx.Tensor<
   f32[6][6]
   [
     [-0.4207117557525635, -0.4737299978733063, -0.7119972109794617, 0.16700628399848938, -0.12663894891738892, 0.21843300759792328],
     [-0.20900508761405945, -0.521405816078186, 0.6500899791717529, 0.46227264404296875, -0.17707623541355133, 0.12955260276794434],
     [-0.3757129907608032, 0.008483347482979298, -0.03162355348467827, 0.2769175171852112, 0.4845382273197174, -0.7391259074211121],
     [0.5079885125160217, -0.483378529548645, -0.0526423342525959, -0.028192348778247833, 0.6876038908958435, 0.17868347465991974],
     [-0.5295675992965698, 0.3443851172924042, 0.16036741435527802, -0.047520797699689865, 0.49480050802230835, 0.5728468894958496],
     [-0.3156961500644684, -0.3890860974788666, 0.20237109065055847, -0.8238182663917542, -0.014236591756343842, -0.17063017189502716]
   ]
 >}

But if we now try to reconstruct the original matrix

{u, s, vt} = Nx.LinAlg.svd(a)  # a is the tensor defined above
{s_size} = Nx.shape(s)

Nx.broadcast(0.0, {s_size, s_size})
|> Nx.put_diagonal(s)
|> then(&Nx.dot(u, &1))
|> Nx.dot(vt)

we get

#Nx.Tensor<
  f32[6][6]
  [
    [0.6398975253105164, 0.3727651536464691, 1.070520043373108, -0.09387847036123276, -0.08923928439617157, -0.46281859278678894],
    [0.601960301399231, 0.6143905520439148, 1.1999882459640503, -0.17572081089019775, 0.6206580996513367, -0.4661707580089569],
    [0.4421979784965515, 0.9742961525917053, 0.6942494511604309, -0.29241636395454407, 0.1774483621120453, -0.5917259454727173],
    [0.5993717908859253, 0.6326770186424255, 0.6780105233192444, -0.4971994161605835, 0.029454946517944336, 0.012043285183608532],
    [0.669061541557312, 0.33210280537605286, 0.36173108220100403, -0.2020602971315384, 0.12434040755033493, 0.026028713211417198],
    [0.37626299262046814, 0.6340826153755188, 1.1991658210754395, -0.08292210102081299, -0.015425268560647964, 0.006294505670666695]
  ]
>

This does not correspond to the original matrix. NumPy's svd, on the other hand, returns

u = [[-0.42250955 -0.69939226 -0.0676828  -0.42475942  0.3837914   0.00623618]
 [-0.51739854  0.5633327  -0.21903265 -0.49597746 -0.13577256  0.3202372 ]
 [-0.4318294  -0.16344908  0.6483023   0.06072902 -0.5897151  -0.122719  ]
 [-0.3598724   0.03894597  0.16080056  0.614588    0.37407485  0.57049775]
 [-0.2380906   0.39848185  0.31266332  0.00907759  0.5394595  -0.6290212 ]
 [-0.4255731  -0.08054024 -0.6352423   0.4382886  -0.23570609 -0.40151447]]
s = [3.1087604  0.8369838  0.69383204 0.5481634  0.38850585 0.20046385]
vt = [[-0.42071193 -0.48033065 -0.4332462  -0.341737   -0.32910848 -0.42365676]
 [ 0.00843725 -0.13155974 -0.48923385  0.6090474   0.5453104  -0.27380434]
 [-0.02949977 -0.24171527  0.71606654  0.14500445  0.10684541 -0.628896  ]
 [ 0.825944   -0.5231329  -0.08960789 -0.01742733 -0.18767026  0.02439201]
 [-0.12618549  0.03089574  0.02414105  0.68985325 -0.70924675  0.06009421]
 [-0.3520702  -0.6472448   0.22712195  0.12267569  0.21132278  0.5880729 ]]

which is correct decomposition (u @ np.diag(s) @ vt gives the original matrix).

Comparing with the Nx output: only the first column of u, the first singular value in s, and the single entry vt[0][0] agree.
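As a quick sanity check of the identity used here (u @ np.diag(s) @ vt reproducing the input), the following NumPy sketch uses a small stand-in matrix with illustrative values, not the issue's 6x6 input:

```python
import numpy as np

# Small stand-in matrix (hypothetical values, not the issue's 6x6 input)
a = np.array([[0.337, 0.844, 0.846],
              [0.444, 0.846, 0.395],
              [0.615, 0.545, 0.956]])

u, s, vt = np.linalg.svd(a)

# The defining identity: u @ diag(s) @ vt reproduces the original matrix
reconstructed = u @ np.diag(s) @ vt
print(np.allclose(reconstructed, a))  # True
```

Any correct SVD implementation must satisfy this reconstruction to within floating-point tolerance, which is the property the Nx result above violates.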

seanmor5 commented 1 year ago

Is this the case on all backends or just binary-backend?

msluszniak commented 1 year ago

I've checked EXLA and Torchx, and they seem to return the same incorrect result.

polvalente commented 1 year ago

@msluszniak Torchx returned the correct result for me. I used: u |> Nx.dot(Nx.make_diagonal(s)) |> Nx.dot(vt) to reconstruct the matrix

msluszniak commented 1 year ago

@polvalente yes I agree, I messed up something with the backend initially, but now I got the correct result for Torchx.

polvalente commented 1 year ago

Ok, so there seem to be two problems happening.

1: In the BinaryBackend, the implementation falls into some sort of fixed point because two singular values start close to each other (BinaryBackend converges them to 0.67 and 0.66; the actual values, returned by both EXLA and JAX, are 0.83 and 0.69). This indicates a poorly conditioned matrix that the BinaryBackend isn't fully prepared to handle.

2: EXLA is returning v instead of vt. If you do u |> Nx.multiply(s) |> Nx.dot(vt |> Nx.LinAlg.adjoint()) you get the original tensor back within a tolerance. This should be an easy fix, but I still need to confirm this is the case for all matrices.
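The v-versus-vt mixup described above can be demonstrated in NumPy (illustrative matrix; for real inputs the adjoint is just the transpose):

```python
import numpy as np

a = np.array([[1.0, 2.0],
              [3.0, 4.0]])
u, s, vt = np.linalg.svd(a)

# Suppose a backend mistakenly returns v (the transpose of vt):
v = vt.T

# Reconstructing with v directly does NOT give back the input...
print(np.allclose(u @ np.diag(s) @ v, a))    # False

# ...but taking the adjoint (transpose, since a is real) recovers vt:
print(np.allclose(u @ np.diag(s) @ v.T, a))  # True
```

This matches the workaround in the comment: taking `Nx.LinAlg.adjoint()` of the returned factor restores the correct reconstruction.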

msluszniak commented 1 year ago

For this tensor

  [
     [0.8583984375, 0.0224609375, 0.708984375, 0.998046875],
     [0.6455078125, 0.9814453125, 0.390625, 0.984375],
     [0.6728515625, 0.21875, 0.3369140625, 0.7587890625],
     [0.8095703125, 0.8837890625, 0.2412109375, 0.7744140625]
  ]

We get

{#Nx.Tensor<
   f32[4][4]
   [
     [-0.5039474368095398, 0.7379504442214966, 0.15171891450881958, 0.42243048548698425],
     [-0.5674749612808228, -0.45060065388679504, 0.676255464553833, -0.13270191848278046],
     [-0.3948972523212433, 0.2652057111263275, -0.3157660663127899, -0.8209834694862366],
     [-0.5177502632141113, -0.42667800188064575, -0.6480368375778198, 0.3604564070701599]
   ]
 >,
 #Nx.Tensor<
   f32[4]
   [2.6915104389190674, 0.8506444096565247, 0.21740709245204926, 0.08219550549983978]
 >,
 #Nx.Tensor<
   f32[4][4]
   [
     [-0.5512740015983582, -0.3966410458087921, -0.7097650170326233, -0.18710029125213623],
     [0.2064405381679535, -0.37330538034439087, 0.2753744423389435, -0.861506998538971],
     [-0.7834685444831848, 0.3717951476573944, 0.4544169008731842, -0.2035943865776062],
     [0.19913889467716217, 0.7517229318618774, -0.46250173449516296, -0.4258503019809723]
   ]
 >}

The u matrix and singular values are correct, but only the first column of vt is correct.

Correct results for reference

[[-0.50394743  0.73795044  0.15171892  0.42243048]
 [-0.56747496 -0.45060066  0.67625544 -0.13270191]
 [-0.39489726  0.26520572 -0.31576606 -0.82098348]
 [-0.51775024 -0.42667801 -0.64803684  0.36045639]]
[2.69151034 0.85064443 0.2174071  0.08219551]
[[-0.551274   -0.41323659 -0.31093851 -0.65471348]
 [ 0.20644053 -0.87550618  0.39218688  0.19251152]
 [-0.78346856  0.11643493  0.50149706  0.34802387]
 [ 0.19913889  0.22173713  0.70569639 -0.6427822 ]]

So I think there is a more general problem with BinaryBackend.

polvalente commented 1 year ago

Yes, there's definitely something to fix in the BinaryBackend. It's strange that the problem affects only vt, and that the returned vt is still orthogonal.

#991 fixes EXLA, at least

polvalente commented 1 year ago

@msluszniak After the merge of #1015 and #1022, we now have the correct results (up to a rounding error):

iex(2)> {u, s, vh} = Nx.LinAlg.svd(t); u |> Nx.dot(Nx.eye(Nx.shape(t)) |> Nx.put_diagonal(s)) |> Nx.dot(vh)
#Nx.Tensor<
  f32[4][4]
  [
    [0.8645899295806885, 0.021466584876179695, 0.7096304297447205, 1.0008362531661987],
    [0.6423190236091614, 0.9786409139633179, 0.3899857699871063, 0.9842138290405273],
    [0.6740410327911377, 0.21641786396503448, 0.3399522006511688, 0.7606345415115356],
    [0.8090260624885559, 0.8779084086418152, 0.24401454627513885, 0.7728657722473145]
  ]
>
iex(3)> t                                                                                                  
#Nx.Tensor<
  f32[4][4]
  [
    [0.8583984375, 0.0224609375, 0.708984375, 0.998046875],
    [0.6455078125, 0.9814453125, 0.390625, 0.984375],
    [0.6728515625, 0.21875, 0.3369140625, 0.7587890625],
    [0.8095703125, 0.8837890625, 0.2412109375, 0.7744140625]
  ]
>
{#Nx.Tensor<
   f32[4][4]
   [
     [0.5066919922828674, -0.7352122068405151, -0.14841154217720032, 0.4251161515712738],
     [0.5657464265823364, 0.4564741551876068, -0.6759865283966064, -0.1207776740193367],
     [0.39590415358543396, -0.265470415353775, 0.299826979637146, -0.8263956904411316],
     [0.516191840171814, 0.4249938130378723, 0.6566023826599121, 0.34893032908439636]
   ]
 >,
 #Nx.Tensor<
   f32[4]
   [2.6914844512939453, 0.8506505489349365, 0.2173856943845749, 0.08223722130060196]
 >,
 #Nx.Tensor<
   f32[4][4]
   [
     [0.5520889163017273, 0.4099556505680084, 0.31237152218818665, 0.6554068922996521],
     [-0.20873655378818512, 0.877675473690033, -0.38823550939559937, -0.18811701238155365],
     [0.7856599688529968, -0.10767427086830139, -0.491233766078949, -0.3603341579437256],
     [0.18530665338039398, 0.22368313372135162, 0.7144758701324463, -0.636509895324707]
   ]
 >}

Note that the signs of u and vh are flipped relative to the earlier results, but that's not an error in itself: the sign changes cancel out when reconstructing the matrix.
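The sign ambiguity mentioned above is a general property of the SVD: negating column i of u together with row i of vt leaves the product unchanged. A minimal NumPy sketch (illustrative matrix):

```python
import numpy as np

a = np.array([[1.0, 2.0],
              [3.0, 4.0]])
u, s, vt = np.linalg.svd(a)

# Flip the sign of u's second column and vt's second row together
flip = np.diag([1.0, -1.0])
u2 = u @ flip    # negates column 1 of u
vt2 = flip @ vt  # negates row 1 of vt

# flip is its own inverse and commutes with diag(s),
# so the flips cancel and the reconstruction is unchanged
print(np.allclose(u2 @ np.diag(s) @ vt2, a))  # True
```

This is why two backends can disagree on the signs of corresponding singular vectors while both being correct decompositions.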