elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Nx.LinAlg.qr does not support vectorization #1556

Closed: jyc closed this issue 2 weeks ago

jyc commented 2 weeks ago

Hello! Thanks again for making Nx. Sorry I haven't yet been able to try out your CPU-based LU implementation from https://github.com/elixir-nx/nx/issues/1388#issuecomment-2442944169!

I noticed today that QR does not seem to work with vectorized inputs in EXLA:

input =
  Nx.tensor([
    [
      [1, 0],
      [0, 1]
    ],
    [
      [0, 1],
      [1, 0]
    ]
  ])
  |> Nx.vectorize(:foo)

Nx.LinAlg.qr(input)

The result is incorrect: in both Q and R, every entry except the first is all zeros.

{#Nx.Tensor<
   vectorized[foo: 2]
   f32[2][2]
   EXLA.Backend<host:0, 0.3767501109.3804364811.204054>
   [
     [
       [1.0, 0.0],
       [0.0, 1.0]
     ],
     [
       [0.0, 0.0],
       [0.0, 0.0]
     ]
   ]
 >,
 #Nx.Tensor<
   vectorized[foo: 2]
   f32[2][2]
   EXLA.Backend<host:0, 0.3767501109.3804364811.204055>
   [
     [
       [1.0, 0.0],
       [0.0, 1.0]
     ],
     [
       [0.0, 0.0],
       [0.0, 0.0]
     ]
   ]
 >}

Also, the process my Livebook is attached to exits with "Abort trap: 6" shortly afterwards! This happens whether I run Nx.LinAlg.qr in a Livebook session or through IEx.

Is there anything I can do to help debug this? Thanks again.
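One way to narrow a bug like this down is to compare the vectorized result against QR run on each matrix separately. The sketch below assumes :nx and :exla are already available (for example in a Livebook with both installed) and that EXLA is the backend under suspicion; the 1.0e-5 tolerance is an arbitrary choice, not something from the thread.

```elixir
# Assumes :nx and :exla are already available (e.g. a Livebook with both
# installed, or `iex -S mix` in a project that depends on them).
Nx.default_backend(EXLA.Backend)

# The two 2x2 matrices from the report, stacked along an ordinary leading axis.
batch = Nx.tensor([[[1, 0], [0, 1]], [[0, 1], [1, 0]]])

# QR of the vectorized input: the code path reported as broken.
{q_vec, r_vec} = batch |> Nx.vectorize(:foo) |> Nx.LinAlg.qr()

# QR of each matrix on its own, stacked back together as a reference.
{q_list, r_list} =
  0..(Nx.axis_size(batch, 0) - 1)
  |> Enum.map(fn i -> Nx.LinAlg.qr(batch[i]) end)
  |> Enum.unzip()

q_ref = Nx.stack(q_list)
r_ref = Nx.stack(r_list)

# Turn :foo back into a plain leading axis, then compare element-wise.
same_as_reference = fn vectorized, reference ->
  close =
    vectorized
    |> Nx.devectorize(keep_names: false)
    |> Nx.all_close(reference, atol: 1.0e-5)

  Nx.to_number(close) == 1
end

q_ok = same_as_reference.(q_vec, q_ref)
r_ok = same_as_reference.(r_vec, r_ref)

IO.inspect({q_ok, r_ok}, label: "vectorized QR matches per-matrix QR")
```

If the per-matrix results are correct but the comparison fails, that points at the batched (vectorized) path specifically rather than at QR itself.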

polvalente commented 2 weeks ago

Which version of Nx are you using? Please try the main branch: while implementing LU I ran into a bug that was also present in QR and could account for this.

I'm getting different results on main than on 0.9.1, and main reconstructs the input correctly :)
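A quick way to confirm that a QR result reconstructs the input is to check Q·R against the input and QᵀQ against the identity for every vectorized entry. This is a sketch assuming :nx and :exla are available; it relies on Nx.dot, Nx.transpose and Nx.all_close operating per entry on vectorized tensors, and the tolerance is an arbitrary choice.

```elixir
# Assumes :nx and :exla are already available (e.g. a Livebook with both
# installed, or `iex -S mix` in a project that depends on them).
Nx.default_backend(EXLA.Backend)

input =
  Nx.tensor([[[1, 0], [0, 1]], [[0, 1], [1, 0]]])
  |> Nx.vectorize(:foo)

{q, r} = Nx.LinAlg.qr(input)

# Nx.all_close on vectorized tensors yields one result per :foo entry;
# devectorize and reduce with Nx.all so that every entry must pass.
close_everywhere = fn left, right ->
  left
  |> Nx.all_close(right, atol: 1.0e-5)
  |> Nx.devectorize()
  |> Nx.all()
  |> Nx.to_number()
  |> Kernel.==(1)
end

# Q·R should reproduce each input matrix.
reconstructs_ok = close_everywhere.(Nx.dot(q, r), input)

# Q should be orthogonal: Qᵀ·Q equals the 2x2 identity for each entry.
orthogonal_ok = close_everywhere.(Nx.dot(Nx.transpose(q), q), Nx.eye(2))

IO.inspect({reconstructs_ok, orthogonal_ok}, label: "Q·R == input, QᵀQ == I")
```

These checks deliberately avoid comparing Q and R against fixed expected values: for a full-rank square matrix, QR is only unique up to the signs of Q's columns and R's rows, so negative entries on R's diagonal are not by themselves a sign of a bug. For example, Q = [[0, -1], [-1, 0]] and R = [[-1, 0], [0, -1]] multiply to [[0, 1], [1, 0]], the second input matrix.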

polvalente commented 2 weeks ago

The bug was due to "overstriding" in the pointer arithmetic. In C++, adding an integer to a pointer already scales that integer by the size of the pointed-to type, and I was also multiplying by the type size explicitly. This might be the cause of the exit you're seeing too.
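To make the overstriding concrete, here is the offset arithmetic worked for the 2×2 f32 batch from the report. The exact expression in the C++ code is an assumption: the thread only says the element size was multiplied in explicitly, so the sketch assumes entry i of a batch of n×n matrices was addressed as `p + i * n * n * sizeof(T)` with `p` a `T*`.

```latex
% Assumed addressing: p + i n^2 sizeof(T), where p has type T*,
% so C++ scales the whole offset by sizeof(T) a second time.
\begin{aligned}
\text{intended byte offset} &= i\,n^2\,\mathrm{sizeof}(T) \\
\text{overstrided byte offset} &= \bigl(i\,n^2\,\mathrm{sizeof}(T)\bigr)\cdot \mathrm{sizeof}(T) = i\,n^2\,\mathrm{sizeof}(T)^2 \\
T = \texttt{f32},\; n = 2,\; i = 1:\quad &16 \text{ bytes intended},\quad 64 \text{ bytes overstrided},\quad 2 \cdot 2^2 \cdot 4 = 32 \text{ bytes in the whole batch}
\end{aligned}
```

Under that assumption, entry 0 sits at offset 0 either way, which fits the first result being correct, while entry 1 is read and written past the end of the buffer. That is a plausible, though unconfirmed, explanation for both the zeroed entries and the later "Abort trap: 6".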

jyc commented 2 weeks ago

Interesting, thanks! I was on 0.8 and just verified that it doesn't work on 0.9.1 either, like you said. I was going to try main, but using:

      {:nx,
       git: "https://github.com/elixir-nx/nx.git", tag: "9e2cd048de610151b85a27a183035bc0873fa77f"},
      {:exla,
       git: "https://github.com/elixir-nx/nx.git", tag: "9e2cd048de610151b85a27a183035bc0873fa77f"},

in mix.exs gave me warnings; I got:

    warning: redefining module NxRoot (current version defined in memory)
    │
  2 │ defmodule NxRoot do
    │ ~~~~~~~~~~~~~~~~~~~
    │
    └─ /Users/jyc/projects/foo/server/deps/exla/mix.exs:2: NxRoot (module)

I imagine this is because Nx and EXLA are both defined in this repository. What would be the correct way to set up the Mix deps?

EDIT: The full error is:

Erlang/OTP 27 [erts-15.0] [source] [64-bit] [smp:10:10] [ds:10:10:10] [async-threads:1] [jit]

    warning: redefining module NxRoot (current version defined in memory)
    │
  2 │ defmodule NxRoot do
    │ ~~~~~~~~~~~~~~~~~~~
    │
    └─ /Users/jyc/projects/foo/server/deps/exla/mix.exs:2: NxRoot (module)

** (Mix) App nx lists itself as a dependency

jyc commented 2 weeks ago

I think you're right that this works on main!

I messed around with options and this in mix.exs seems to work, although I have no idea whether it's correct:

      {:nx, git: "https://github.com/elixir-nx/nx.git", sparse: "nx", ref: "main", override: true},
      {:exla, git: "https://github.com/elixir-nx/nx.git", sparse: "exla", ref: "main"},

The output is now:

{#Nx.Tensor<
   vectorized[foo: 2]
   f32[2][2]
   EXLA.Backend<host:0, 0.1453107890.856555539.251335>
   [
     [
       [1.0, 0.0],
       [0.0, 1.0]
     ],
     [
       [0.0, -1.0],
       [-1.0, 0.0]
     ]
   ]
 >,
 #Nx.Tensor<
   vectorized[foo: 2]
   f32[2][2]
   EXLA.Backend<host:0, 0.1453107890.856555539.251336>
   [
     [
       [1.0, 0.0],
       [0.0, 1.0]
     ],
     [
       [-1.0, 0.0],
       [0.0, -1.0]
     ]
   ]
 >}

jyc commented 2 weeks ago

Thanks so much for the help! For my own edification, is 7af065eb2b7819b65af54ce0378beb76581b034f the commit that contains the fix you mentioned?

polvalente commented 2 weeks ago

Glad it worked! Yes, that's the commit with the fix :)

Also, for the future, you can use a shorter notation for GitHub deps:

{:nx, github: "elixir-nx/nx", branch: "main", sparse: "nx"}
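Combining that shorthand with the `sparse:` options from jyc's working config gives a deps list like the one below. This is a sketch: `override: true` is carried over from jyc's config rather than confirmed as necessary, and `defp deps` assumes the usual `mix new` project layout. The `sparse:` option appears to be what matters: without it, both checkouts resolve to the repository root, whose mix.exs defines `NxRoot`, which is consistent with the "redefining module NxRoot" warning and the "App nx lists itself as a dependency" error.

```elixir
# mix.exs (sketch): Nx and EXLA live in one repository, so each dep uses
# `sparse:` to point Mix at its own subdirectory instead of the repo root.
defp deps do
  [
    {:nx, github: "elixir-nx/nx", branch: "main", sparse: "nx", override: true},
    {:exla, github: "elixir-nx/nx", branch: "main", sparse: "exla"}
  ]
end
```

To pin the commit with the fix instead of tracking main, `branch: "main"` can be swapped for `ref: "7af065eb2b7819b65af54ce0378beb76581b034f"`; Mix's git deps take commit SHAs via `ref:`, not `tag:`.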