rust-ml / linfa

A Rust machine learning framework.
Apache License 2.0
3.77k stars 252 forks source link

How to do clustering grid search with multiple CPUs / GPUs? #337

Open jamesaphoenix opened 8 months ago

jamesaphoenix commented 8 months ago

I'm currently building a Wasm project that exposes some clustering functionality to the browser.

Questions:

- How can I run a grid search over the parameters of these clustering algorithms, and can that search use multiple CPUs (or GPUs)?

I'm looking to use all of these:

- https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/dbscan.rs
- https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/kmeans.rs
- https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/optics.rs

Thanks in advance, and great package btw!

jamesaphoenix commented 8 months ago

Here is my current library, for context:

use linfa::traits::Fit;
use linfa::traits::Predict;
use linfa::DatasetBase;
use linfa_clustering::KMeans;
use linfa_nn::distance::LInfDist;
use ndarray::Array2;
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use serde::{Deserialize, Serialize};
use serde_json;
use wasm_bindgen::prelude::*;

// Data types:

/// One input record: a keyword and its embedding vector, deserialized from the JSON
/// string passed in from JavaScript.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct Embedding {
    /// The keyword this embedding belongs to.
    keyword: String,
    /// The embedding vector. All embeddings in one request must have the same length,
    /// because they are packed into a single rectangular matrix for clustering.
    embeddings: Vec<f64>,
}

/// An input [`Embedding`] annotated with the k-means cluster it was assigned to.
/// This is the element type of the JSON array returned to JavaScript.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct EnrichedEmbedding {
    /// The original input record, echoed back unchanged.
    embedding: Embedding,
    /// Index of the assigned cluster, in `0..n_clusters`.
    cluster: usize,
    /// Presumably meant to flag the representative keyword of each cluster.
    /// NOTE(review): the clustering code currently always sets this to `false` (placeholder).
    is_main_keyword_in_cluster: bool,
}

// JavaScript imports bound through wasm-bindgen.
#[wasm_bindgen]
extern "C" {
    /// Writes `s` to the browser console via `console.log` (the `js_namespace = console`
    /// attribute selects the `console` object).
    ///
    /// NOTE(review): this is a JS import, so calling it from a non-wasm host (e.g. a native
    /// `cargo test`) is expected to fail; confirm against the wasm-bindgen version in use.
    #[wasm_bindgen(js_namespace = console)]
    fn log(s: &str);
}

/// Builds a greeting for `name`; `greet("world")` yields `"Hello, world!"`.
/// Exported to JavaScript, mainly as a smoke test for the wasm bindings.
#[wasm_bindgen]
pub fn greet(name: &str) -> String {
    let mut greeting = String::from("Hello, ");
    greeting.push_str(name);
    greeting.push('!');
    greeting
}

// TODO - If there are no keywords then raise an error:

#[wasm_bindgen]
pub fn cluster_embeddings_with_kmeans(
    json_embeddings: &str,
    n_clusters: usize,
) -> Result<String, JsValue> {
    let rng = Xoshiro256Plus::seed_from_u64(42);

    // Deserialize JSON embeddings:
    let embeddings: Vec<Embedding> =
        serde_json::from_str(json_embeddings).map_err(|e| JsValue::from_str(&e.to_string()))?;

    println!("Number of embeddings: {}", embeddings.len());

    // If there are more than 100,000 embeddings:
    if embeddings.len() > 100000 {
        return Err(JsValue::from_str(
            "The number of embeddings is too large. Please use a smaller dataset.",
        ));
    }

    if embeddings.len() == 0 {
        return Err(JsValue::from_str(
            "The number of embeddings is 0. Please provide some embeddings.",
        ));
    }

    // Convert embeddings to ndarray
    let rows = embeddings.len();
    let cols = embeddings[0].embeddings.len();
    let flattened: Vec<f64> = embeddings
        .iter()
        .flat_map(|e| e.embeddings.clone())
        .collect();
    let array = Array2::from_shape_vec((rows, cols), flattened)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let dataset = DatasetBase::from(array);

    log("Clustering embeddings in Rust...");

    // Cluster embeddings in Rust:
    let model = KMeans::params_with(n_clusters, rng, LInfDist)
        .max_n_iterations(1000)
        .fit(&dataset)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;

    log("Finished clustering embeddings in Rust");
    log("Assigning points to clusters...");

    // Assign each point to a cluster using the set of centroids found using `fit`
    let dataset = model.predict(dataset);
    let DatasetBase {
        records, targets, ..
    } = dataset;

    // Assuming you want to correlate the original embeddings with their cluster assignments
    let enriched_embeddings: Vec<EnrichedEmbedding> = embeddings
        .into_iter()
        .zip(targets.iter())
        .map(|(embedding, &cluster)| {
            EnrichedEmbedding {
                embedding,
                cluster: cluster as usize,
                is_main_keyword_in_cluster: false, // Placeholder logic here
            }
        })
        .collect();

    // Serialize the enriched embeddings
    serde_json::to_string(&enriched_embeddings).map_err(|e| JsValue::from_str(&e.to_string()))
}

#[cfg(test)]
mod tests {
    use super::*;
    use wasm_bindgen_test::*;

    /// Builds a valid JSON array of `count` embeddings. The vectors cycle through ten distinct
    /// 3-d points, so k-means is never asked to split a set of identical points.
    fn make_embeddings_json(count: usize) -> String {
        let items: Vec<String> = (0..count)
            .map(|i| {
                let x = (i % 10) as f64 / 10.0;
                format!(
                    r#"{{"keyword": "kw{}", "embeddings": [{}, {}, {}]}}"#,
                    i, x, x, x
                )
            })
            .collect();
        format!("[{}]", items.join(","))
    }

    #[test]
    fn testing_greeting() {
        assert_eq!(greet("world"), "Hello, world!");
    }

    #[wasm_bindgen_test]
    fn test_cluster_embeddings() {
        let mock_json = r#"
            [
                {
                    "keyword": "rust",
                    "embeddings": [0.1, 0.2, 0.3]
                },
                {
                    "keyword": "wasm",
                    "embeddings": [0.4, 0.5, 0.6]
                }
            ]
        "#;

        let n_clusters = 2; // For simplicity, choose a small number of clusters

        let result = cluster_embeddings_with_kmeans(mock_json, n_clusters);
        assert!(result.is_ok());

        // Deserialize the result to verify its structure.
        let enriched_embeddings: Vec<EnrichedEmbedding> =
            serde_json::from_str(&result.unwrap()).unwrap();

        // Every embedding is returned and assigned a valid cluster.
        assert_eq!(enriched_embeddings.len(), 2);
        for enriched_embedding in enriched_embeddings {
            assert!(enriched_embedding.cluster < n_clusters);
        }
    }

    #[wasm_bindgen_test]
    fn test_cluster_embeddings_with_no_embeddings() {
        let result = cluster_embeddings_with_kmeans("[]", 2);
        assert!(result.is_err());
    }

    #[wasm_bindgen_test]
    fn test_cluster_embeddings_with_ragged_embeddings() {
        // Vectors of different lengths cannot form a rectangular matrix.
        let mock_json = r#"[
            {"keyword": "rust", "embeddings": [0.1, 0.2, 0.3]},
            {"keyword": "wasm", "embeddings": [0.4, 0.5]}
        ]"#;
        assert!(cluster_embeddings_with_kmeans(mock_json, 2).is_err());
    }

    #[wasm_bindgen_test]
    fn test_cluster_embeddings_with_large_dataset() {
        // One entry over the 100,000-embedding limit, as a *valid* JSON array, so the failure
        // must come from the size check rather than from a parse error.
        let mock_json = make_embeddings_json(100_001);

        let result = cluster_embeddings_with_kmeans(&mock_json, 2);

        let message = result.unwrap_err().as_string().unwrap_or_default();
        assert!(message.contains("too large"), "unexpected error: {}", message);
    }

    // Previously this test was nested inside another test function, so the harness never ran it.
    #[wasm_bindgen_test]
    fn test_cluster_embeddings_with_3k_embeddings() {
        let mock_json = make_embeddings_json(3000);
        let n_clusters = 2; // For simplicity, choose a small number of clusters

        let result = cluster_embeddings_with_kmeans(&mock_json, n_clusters);

        let enriched: Vec<EnrichedEmbedding> = serde_json::from_str(&result.unwrap()).unwrap();
        assert_eq!(enriched.len(), 3000);
        assert!(enriched.iter().all(|e| e.cluster < n_clusters));
    }
}
jamesaphoenix commented 7 months ago

Bump on this?