elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

EXLA error for large tensor #1506

Status: Open (opened by msluszniak 1 month ago)

msluszniak commented 1 month ago

The following code:

res =
  EXLA.jit(&Scholar.Manifold.MDS.fit(&1, key: &2, num_components: 2)).(
    Nx.iota({1000000, 3}),
    Nx.Random.key(42)
  )

raises an error:

** (RuntimeError) Unable to get dimensions.
    (exla 0.7.2) lib/exla/shape.ex:89: EXLA.Shape.unwrap!/1
    (exla 0.7.2) lib/exla/shape.ex:29: EXLA.Shape.make_shape/2
    (exla 0.7.2) lib/exla/defn.ex:914: EXLA.Defn.to_operator/4
    (exla 0.7.2) lib/exla/defn.ex:898: EXLA.Defn.cached_recur_operator/4
    (exla 0.7.2) lib/exla/defn.ex:657: EXLA.Defn.recur_operator/3
    (exla 0.7.2) lib/exla/defn.ex:2425: EXLA.Defn.recur_composite/4
    (elixir 1.15.5) lib/enum.ex:1819: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
    #cell:lph4otuox3sqx2ec:2: (file)

For smaller tensors, such as Nx.iota({1000, 3}), the error does not occur.

seanmor5 commented 4 weeks ago

We killed the EXLA.Shape module, which is where this error is raised. Does this still occur on main? That error only fires if there is a problem getting a value from the dimensions tuple: the logic calls enif_get_tuple and then loops through enif_get_int64 calls, one per dimension. The only way it could fail is if an integer were out of range for the type, which is not the case here, so the error is confusing.
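To make the described check concrete, here is a minimal Elixir sketch that mirrors that decoding logic. It is an assumption-labeled illustration, not EXLA's actual code (which runs in a NIF), and the `DimsCheck` module and `unwrap_dims!/1` function are hypothetical. Each dimension must be an integer that fits in a signed 64-bit value, which is the condition enif_get_int64 enforces, and both 1000000 and 3 pass it easily.

```elixir
defmodule DimsCheck do
  # Hypothetical mirror of the NIF-side logic: walk the dimensions tuple
  # and accept each element only if it is an integer in signed 64-bit range,
  # the same condition under which enif_get_int64 succeeds.
  @int64_min -Integer.pow(2, 63)
  @int64_max Integer.pow(2, 63) - 1

  def unwrap_dims!(dims) when is_tuple(dims) do
    dims
    |> Tuple.to_list()
    |> Enum.map(fn
      d when is_integer(d) and d >= @int64_min and d <= @int64_max -> d
      # Any non-integer or out-of-range value triggers the error from the report
      _ -> raise "Unable to get dimensions."
    end)
    |> List.to_tuple()
  end
end
```

With this sketch, `DimsCheck.unwrap_dims!({1_000_000, 3})` returns the tuple unchanged, and only a value such as `Integer.pow(2, 63)` would raise, which is why the shape from the report should not trip this check.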