elixir-nx / scholar

Traditional machine learning on top of Nx
Apache License 2.0
408 stars 43 forks source link

Hierarchical clustering improvements #213

Status: Open · josevalim opened this issue 9 months ago

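josevalim commented 9 months ago (the text of this comment — the original issue description listing the proposed improvements — was not captured in this copy)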
josevalim commented 8 months ago

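Here is a small patch for the first item (it targets the Ward linkage by tracking per-cluster sizes and un-skips the Ward test). Unfortunately it is not enough on its own, so something else is probably wrong as well: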

diff --git a/lib/scholar/cluster/hierarchical.ex b/lib/scholar/cluster/hierarchical.ex
index 481b6cd..562367f 100644
--- a/lib/scholar/cluster/hierarchical.ex
+++ b/lib/scholar/cluster/hierarchical.ex
@@ -196,10 +196,11 @@ defmodule Scholar.Cluster.Hierarchical do
     clades = Nx.broadcast(-1, {n - 1, 2})
     sizes = Nx.broadcast(1, {2 * n - 1})
     pointers = Nx.broadcast(-1, {2 * n - 2})
+    n_sizes = Nx.broadcast(1, {n})
     diss = Nx.tensor(:infinity, type: Nx.type(pairwise)) |> Nx.broadcast({n - 1})

-    {{clades, diss, sizes}, _} =
-      while {{clades, diss, sizes}, {count = 0, pointers, pairwise}}, count < n - 1 do
+    {{clades, diss, sizes, n_sizes}, _} =
+      while {{clades, diss, sizes, n_sizes}, {count = 0, pointers, pairwise}}, count < n - 1 do
         # Indexes of who I am nearest to
         nearest = Nx.argmin(pairwise, axis: 1)

@@ -213,10 +214,21 @@ defmodule Scholar.Cluster.Hierarchical do
         # They are bidirectional but let's keep only one side.
         links = Nx.select(clades_selector and nearest > nearest_of_nearest, nearest, n)

-        {clades, count, pointers, pairwise, diss, sizes} =
-          merge_clades(clades, count, pointers, pairwise, diss, sizes, links, n, update_fun)
-
-        {{clades, diss, sizes}, {count, pointers, pairwise}}
+        {clades, count, pointers, pairwise, diss, sizes, n_sizes} =
+          merge_clades(
+            clades,
+            count,
+            pointers,
+            pairwise,
+            diss,
+            sizes,
+            n_sizes,
+            links,
+            n,
+            update_fun
+          )
+
+        {{clades, diss, sizes, n_sizes}, {count, pointers, pairwise}}
       end

     sizes = sizes[n..(2 * n - 2)]
@@ -224,16 +236,27 @@ defmodule Scholar.Cluster.Hierarchical do
     {clades[perm], diss[perm], sizes[perm]}
   end

-  defnp merge_clades(clades, count, pointers, pairwise, diss, sizes, links, n, update_fun) do
-    {{clades, count, pointers, pairwise, diss, sizes}, _} =
-      while {{clades, count, pointers, pairwise, diss, sizes}, links},
+  defnp merge_clades(
+          clades,
+          count,
+          pointers,
+          pairwise,
+          diss,
+          sizes,
+          n_sizes,
+          links,
+          n,
+          update_fun
+        ) do
+    {{clades, count, pointers, pairwise, diss, sizes, n_sizes}, _} =
+      while {{clades, count, pointers, pairwise, diss, sizes, n_sizes}, links},
             i <- 0..(Nx.size(links) - 1) do
         # i < j because of how links is formed.
         # i will become the new clade index and we "infinity-out" j.
         j = links[i]

         if j == n do
-          {{clades, count, pointers, pairwise, diss, sizes}, links}
+          {{clades, count, pointers, pairwise, diss, sizes, n_sizes}, links}
         else
           # Clades a and b (i and j of pairwise) are being merged into c.
           indices = [i, j] |> Nx.stack() |> Nx.new_axis(-1)
@@ -251,6 +274,9 @@ defmodule Scholar.Cluster.Hierarchical do
           sc = sa + sb
           sizes = Nx.indexed_put(sizes, Nx.stack([i, c]) |> Nx.new_axis(-1), Nx.stack([sc, sc]))

+          n_sizes =
+            Nx.indexed_put(n_sizes, Nx.stack([i, j]) |> Nx.new_axis(-1), Nx.stack([sc, sc]))
+
           # Update dissimilarities
           diss = Nx.indexed_put(diss, Nx.stack([count]), pairwise[i][j])

@@ -259,7 +285,7 @@ defmodule Scholar.Cluster.Hierarchical do

           # Update pairwise
           updates =
-            update_fun.(pairwise[i], pairwise[j], pairwise[i][j], sa, sb, sc)
+            update_fun.(pairwise[i], pairwise[j], pairwise[i][j], sa, sb, n_sizes)
             |> Nx.indexed_put(indices, Nx.broadcast(:infinity, {2}))

           pairwise =
@@ -269,11 +295,11 @@ defmodule Scholar.Cluster.Hierarchical do
             |> Nx.put_slice([j, 0], Nx.broadcast(:infinity, {1, n}))
             |> Nx.put_slice([0, j], Nx.broadcast(:infinity, {n, 1}))

-          {{clades, count + 1, pointers, pairwise, diss, sizes}, links}
+          {{clades, count + 1, pointers, pairwise, diss, sizes, n_sizes}, links}
         end
       end

-    {clades, count, pointers, pairwise, diss, sizes}
+    {clades, count, pointers, pairwise, diss, sizes, n_sizes}
   end

   defnp find_clade(pointers, i) do
diff --git a/test/scholar/cluster/hierarchical_test.exs b/test/scholar/cluster/hierarchical_test.exs
index 6c4e5d5..4511252 100644
--- a/test/scholar/cluster/hierarchical_test.exs
+++ b/test/scholar/cluster/hierarchical_test.exs
@@ -127,7 +127,6 @@ defmodule Scholar.Cluster.HierarchicalTest do
       assert model.dissimilarities == Nx.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0])
     end

-    @tag :skip
     test "ward", %{data: data} do
       model = Hierarchical.fit(data, linkage: :ward)
josevalim commented 8 months ago

I have commented out Ward for now; see commit 6845727ee7889d085a9d79cec948dcf3c94ed2bc.