elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

Scholar.Neighbors.KDTree.predict fails when using EXLA as the backend #1471

Closed: msluszniak closed this issue 6 months ago

msluszniak commented 7 months ago

The following code

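# Setup assumed: Req, Explorer, Nx, EXLA and Scholar are installed (e.g. via Mix.install in Livebook)
# EXLA as the default backend is the configuration that triggers the error
Nx.global_default_backend(EXLA.Backend)
alias Explorer.DataFrame, as: DF
alias Explorer.Series, as: S

# NHANES 2013-2014 age prediction subset from the UCI repository (a zip holding one CSV)
nhanes_url =
  "https://archive.ics.uci.edu/static/public/887/national+health+and+nutrition+health+survey+2013-2014+(nhanes)+age+prediction+subset.zip"

zip_data = Req.get!(nhanes_url).body

# Unzip in memory; the pattern match expects exactly one file in the archive
{:ok, [{_filename, csv}]} = :zip.extract(zip_data, [:memory])
df_data = DF.load_csv!(csv)

# Features: every column except the label and the respondent ID, stacked column-wise
x = df_data |> DF.discard(["age_group", "SEQN"]) |> Nx.stack(axis: 1)
# Labels: age_group encoded as category indices, as a single-column tensor
y = Nx.stack(S.cast(df_data["age_group"], :category), axis: 1)

# Fit and query a KD-tree on the first 41 rows (Elixir ranges are inclusive)
knn_model = Scholar.Neighbors.KDTree.fit(x[[0..40, ..]], k: 4)
knn = Scholar.Neighbors.KDTree.predict(knn_model, x[[0..40, ..]], k: 4)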

causes the following error when the backend is set to EXLA:

** (RuntimeError) Executable expected parameter 2 of size 1 but got buffer with incompatible size 2
    (exla 0.7.1) lib/exla/executable.ex:123: EXLA.Executable.unwrap!/1
    (exla 0.7.1) lib/exla/executable.ex:19: EXLA.Executable.run/3
    (exla 0.7.1) lib/exla/defn.ex:492: EXLA.Defn.maybe_outfeed/7
    (stdlib 5.1.1) timer.erl:270: :timer.tc/2
    (exla 0.7.1) lib/exla/defn.ex:413: anonymous fn/7 in EXLA.Defn.__compile__/4
    (nx 0.7.1) lib/nx/defn.ex:452: Nx.Defn.do_jit_apply/3
    (nx 0.7.1) lib/nx/defn/evaluator.ex:441: Nx.Defn.Evaluator.eval_apply/4
    #cell:urdvzl3celu2tjnu:5: (file)

josevalim commented 7 months ago

I have pushed a fix for this to Scholar main and further isolated the bug in the scholar/jv-vectorize-bug branch.