jamesaphoenix opened 8 months ago
Here is my current library, to provide some context:
use linfa::traits::Fit;
use linfa::traits::Predict;
use linfa::DatasetBase;
use linfa_clustering::KMeans;
use linfa_nn::distance::LInfDist;
use ndarray::Array2;
use ndarray_rand::rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use serde::{Deserialize, Serialize};
use serde_json;
use wasm_bindgen::prelude::*;
// Data types:
#[derive(Serialize, Deserialize)]
struct Embedding {
keyword: String,
embeddings: Vec<f64>,
}
#[derive(Serialize, Deserialize)]
struct EnrichedEmbedding {
embedding: Embedding,
cluster: usize,
is_main_keyword_in_cluster: bool,
}
#[wasm_bindgen]
extern "C" {
#[wasm_bindgen(js_namespace = console)]
fn log(s: &str);
}
#[wasm_bindgen]
pub fn greet(name: &str) -> String {
format!("Hello, {}!", name)
}
// TODO - If there are no keywords then raise an error:
#[wasm_bindgen]
pub fn cluster_embeddings_with_kmeans(
json_embeddings: &str,
n_clusters: usize,
) -> Result<String, JsValue> {
let rng = Xoshiro256Plus::seed_from_u64(42);
// Deserialize JSON embeddings:
let embeddings: Vec<Embedding> =
serde_json::from_str(json_embeddings).map_err(|e| JsValue::from_str(&e.to_string()))?;
// println! output is not visible from wasm in the browser, so log via the console binding instead:
log(&format!("Number of embeddings: {}", embeddings.len()));
// If there are more than 100,000 embeddings:
if embeddings.len() > 100000 {
return Err(JsValue::from_str(
"The number of embeddings is too large. Please use a smaller dataset.",
));
}
if embeddings.is_empty() {
return Err(JsValue::from_str(
"The number of embeddings is 0. Please provide some embeddings.",
));
}
// Convert embeddings to ndarray
let rows = embeddings.len();
let cols = embeddings[0].embeddings.len();
let flattened: Vec<f64> = embeddings
.iter()
.flat_map(|e| e.embeddings.clone())
.collect();
let array = Array2::from_shape_vec((rows, cols), flattened)
.map_err(|e| JsValue::from_str(&e.to_string()))?;
let dataset = DatasetBase::from(array);
log("Clustering embeddings in Rust...");
// Cluster embeddings in Rust:
let model = KMeans::params_with(n_clusters, rng, LInfDist)
.max_n_iterations(1000)
.fit(&dataset)
.map_err(|e| JsValue::from_str(&e.to_string()))?;
log("Finished clustering embeddings in Rust");
log("Assigning points to clusters...");
// Assign each point to a cluster using the set of centroids found using `fit`
let dataset = model.predict(dataset);
let DatasetBase {
records, targets, ..
} = dataset;
// Assuming you want to correlate the original embeddings with their cluster assignments
let enriched_embeddings: Vec<EnrichedEmbedding> = embeddings
.into_iter()
.zip(targets.iter())
.map(|(embedding, &cluster)| {
EnrichedEmbedding {
embedding,
cluster: cluster as usize,
is_main_keyword_in_cluster: false, // Placeholder logic here
}
})
.collect();
// Serialize the enriched embeddings
serde_json::to_string(&enriched_embeddings).map_err(|e| JsValue::from_str(&e.to_string()))
}
#[cfg(test)]
mod tests {
use super::*;
use wasm_bindgen_test::*;
#[test]
fn testing_greeting() {
assert_eq!(greet("world"), "Hello, world!");
}
#[wasm_bindgen_test]
fn test_cluster_embeddings() {
let mock_json = r#"
[
{
"keyword": "rust",
"embeddings": [0.1, 0.2, 0.3]
},
{
"keyword": "wasm",
"embeddings": [0.4, 0.5, 0.6]
}
]
"#;
let n_clusters = 2; // For simplicity, choose a small number of clusters
// Call the function with the mocked JSON and the number of clusters
let result = cluster_embeddings_with_kmeans(mock_json, n_clusters);
// Check that the function succeeded
assert!(result.is_ok());
// Deserialize the result to verify its structure
let enriched_embeddings: Vec<EnrichedEmbedding> =
serde_json::from_str(&result.unwrap()).unwrap();
// Verify that each embedding has been assigned a cluster
assert_eq!(enriched_embeddings.len(), 2);
for enriched_embedding in enriched_embeddings {
assert!(enriched_embedding.cluster < n_clusters);
}
}
#[wasm_bindgen_test]
fn test_cluster_embeddings_with_no_embeddings() {
let mock_json = r#"
[]
"#;
let n_clusters = 2; // For simplicity, choose a small number of clusters
let result = cluster_embeddings_with_kmeans(mock_json, n_clusters);
assert!(result.is_err())
}
#[wasm_bindgen_test]
fn test_cluster_embeddings_with_large_dataset() {
// Mock over 100k embeddings to trigger the size error:
let single_embedding = r#"{"keyword": "rust", "embeddings": [0.1, 0.2, 0.3]}"#;
// Build a JSON array of 100,001 embeddings (one over the limit):
let mut mock_json_new = String::from("[");
for i in 0..100001 {
if i > 0 {
mock_json_new.push(',');
}
mock_json_new.push_str(single_embedding);
}
mock_json_new.push(']');
let n_clusters = 2; // For simplicity, choose a small number of clusters
// Call the function with the mocked JSON and the number of clusters
let result = cluster_embeddings_with_kmeans(&mock_json_new, n_clusters);
// Check that the function failed:
assert!(result.is_err());
}
#[wasm_bindgen_test]
fn test_cluster_embeddings_with_3k_embeddings() {
let mut mock_json_new = String::from("[");
let single_embedding = r#"{"keyword": "rust", "embeddings": [0.1, 0.2, 0.3]}"#;
for i in 0..3000 {
if i > 0 {
mock_json_new.push(',');
}
mock_json_new.push_str(single_embedding);
}
mock_json_new.push(']');
let n_clusters = 2; // For simplicity, choose a small number of clusters
// Call the function with the mocked JSON and the number of clusters
let result = cluster_embeddings_with_kmeans(&mock_json_new, n_clusters);
assert!(result.is_ok());
}
}
Bump on this?
Currently I'm building a wasm project that will expose some clustering functionality to the browser.
Questions:
I'm looking to use all of these:
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/dbscan.rs
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/kmeans.rs
https://github.com/rust-ml/linfa/blob/master/algorithms/linfa-clustering/examples/optics.rs
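To make the question concrete, here is a rough, untested sketch of how I imagine a DBSCAN variant slotting in next to the KMeans function above. It reuses the same Embedding / EnrichedEmbedding types and follows the pattern from the dbscan.rs example; the function name, the min_points / tolerance parameters, and the noise handling are all just my guesses:

use linfa::traits::Transformer;
use linfa_clustering::Dbscan;

#[wasm_bindgen]
pub fn cluster_embeddings_with_dbscan(
    json_embeddings: &str,
    min_points: usize,
    tolerance: f64,
) -> Result<String, JsValue> {
    // Same JSON -> Vec<Embedding> -> Array2<f64> conversion as the KMeans path:
    let embeddings: Vec<Embedding> =
        serde_json::from_str(json_embeddings).map_err(|e| JsValue::from_str(&e.to_string()))?;
    if embeddings.is_empty() {
        return Err(JsValue::from_str("No embeddings provided."));
    }
    let rows = embeddings.len();
    let cols = embeddings[0].embeddings.len();
    let flattened: Vec<f64> = embeddings
        .iter()
        .flat_map(|e| e.embeddings.clone())
        .collect();
    let records = Array2::from_shape_vec((rows, cols), flattened)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let dataset = DatasetBase::from(records);

    // DBSCAN is a Transformer rather than Fit + Predict: it labels each row
    // with an Option<usize>, where None marks a noise point.
    let clustered = Dbscan::params(min_points)
        .tolerance(tolerance)
        .transform(dataset)
        .map_err(|e| JsValue::from_str(&e.to_string()))?;
    let DatasetBase { targets, .. } = clustered;

    // Placeholder: map noise (None) to usize::MAX just to keep the existing
    // EnrichedEmbedding shape; an Option<usize> field would probably be nicer.
    let enriched: Vec<EnrichedEmbedding> = embeddings
        .into_iter()
        .zip(targets.iter())
        .map(|(embedding, &cluster)| EnrichedEmbedding {
            embedding,
            cluster: cluster.unwrap_or(usize::MAX),
            is_main_keyword_in_cluster: false,
        })
        .collect();

    serde_json::to_string(&enriched).map_err(|e| JsValue::from_str(&e.to_string()))
}

OPTICS looks like it produces a reachability analysis rather than direct labels (going by its example), so I haven't sketched that one; pointers on whether the shape above is right for it too would be appreciated.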
Thanks in advance, and great package btw!