ericphanson / UnbalancedOptimalTransport.jl

Sinkhorn divergences for measures of unequal mass
Other
14 stars 1 forks source link

Rust port of some of the code #13

Open ericphanson opened 11 months ago

ericphanson commented 11 months ago

I tried porting unbalanced_sinkhorn! to rust (which I don't know at all) using ChatGPT, just as an experiment[^1]. This is probably not very good or idiomatic rust code, but it does produce the same value as the Julia package. When I tried porting the sinkhorn_divergence code the rust compiler alerted me to the aliasing issue in #11. It also seems to run the example in main a few times faster than my Julia code, but I don't know if it's cheating given the inputs are getting compiled in rather than being runtime inputs.

use ndarray::Array;
use ndarray::{Array2, Axis};
use rand::{Rng, SeedableRng};
use std::time::Instant;

// Simple version of:
// https://github.com/ericphanson/UnbalancedOptimalTransport.jl/blob/2573b06fd4ec28c0e219ab4585443060450bf1e0/src/UnbalancedOptimalTransport.jl#L16-L49
/// A discrete measure: per-point log-densities together with the Sinkhorn
/// dual potential and a same-length scratch buffer used by the solver.
#[derive(Debug)]
struct DiscreteMeasure {
    log_density: Vec<f64>,
    dual_potential: Vec<f64>,
    cache: Vec<f64>,
}

impl DiscreteMeasure {
    /// Builds a measure from its log-densities; the dual potential and the
    /// scratch cache start out zero-filled with matching length.
    fn new(log_density: Vec<f64>) -> Self {
        let len = log_density.len();
        Self {
            dual_potential: vec![0.0; len],
            cache: vec![0.0; len],
            log_density,
        }
    }
}

// https://github.com/ericphanson/UnbalancedOptimalTransport.jl/blob/2573b06fd4ec28c0e219ab4585443060450bf1e0/test/runtests.jl#L11-L26
/// Draws `n` log-densities uniformly from `[0, scale)` with a seeded
/// `StdRng`, so the same `seed` always yields the same measure.
fn rand_measure(n: usize, scale: f64, seed: u64) -> DiscreteMeasure {
    let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
    let mut log_density = Vec::with_capacity(n);
    for _ in 0..n {
        log_density.push(scale * rng.gen::<f64>());
    }
    DiscreteMeasure::new(log_density)
}

// https://github.com/ericphanson/UnbalancedOptimalTransport.jl/blob/2573b06fd4ec28c0e219ab4585443060450bf1e0/src/divergences.jl#L49
/// Soft-thresholding factor from the Julia `approx` helper: scales `x` by
/// `1 / (1 + ϵ/ρ)`, the shrinkage used for the KL-penalized dual update.
fn approx(ρ: f64, ϵ: f64, x: f64) -> f64 {
    // Keep the reciprocal-then-multiply form so results match the original
    // bit-for-bit (x / (1 + ϵ/ρ) could differ in the last ulp).
    let shrink = 1.0 / (1.0 + ϵ / ρ);
    shrink * x
}

// https://github.com/ericphanson/UnbalancedOptimalTransport.jl/blob/2573b06fd4ec28c0e219ab4585443060450bf1e0/src/utilities.jl#L1-L18
/// Numerically stable `log(Σᵢ exp(w[i]))`: shift by the maximum, then use
/// `ln_1p` on the sum of the remaining terms (the max term is exactly 1.0).
///
/// NOTE: `w` is used as scratch space — on return each entry holds
/// `exp(w[i] - max)` rather than its original value.
///
/// Returns `f64::NEG_INFINITY` for an empty slice (log of an empty sum).
/// Panics if any entry is NaN (NaN has no order for `partial_cmp`).
fn logsumexp(w: &mut [f64]) -> f64 {
    // Fix: the original unwrapped `max_by` on an empty iterator and panicked.
    if w.is_empty() {
        return f64::NEG_INFINITY;
    }

    let maxind = w
        .iter()
        .enumerate()
        .max_by(|&(_, a), &(_, b)| a.partial_cmp(b).unwrap())
        .map(|(i, _)| i)
        .unwrap();
    let offset = w[maxind];

    // Shift so the largest exponent is exactly 0, hence exp(...) == 1.0.
    for elem in w.iter_mut() {
        *elem = (*elem - offset).exp();
    }

    // Sum every term except the max's contribution of 1.0, so ln_1p can
    // recover precision when one term dominates.
    w[maxind] -= 1.0;
    let sum_except_max: f64 = w.iter().sum();
    w[maxind] += 1.0;

    sum_except_max.ln_1p() + offset
}

// https://github.com/ericphanson/UnbalancedOptimalTransport.jl/blob/2573b06fd4ec28c0e219ab4585443060450bf1e0/src/sinkhorn.jl#L13-L105
/// Runs unbalanced Sinkhorn iterations in-place on the dual potentials of
/// `a` and `b` for the cost matrix `C` (rows index points of `a`, columns
/// index points of `b`), divergence parameter `D` (ρ in the Julia source)
/// and entropic regularization `ϵ`.
///
/// Alternates closed-form updates of the two potentials until the largest
/// absolute change falls to `tol`, or `max_iters` sweeps have run (printing
/// a warning in that case when `warn` is set).
///
/// Returns `(iterations_used, last_max_residual)`.
#[allow(non_snake_case)]
fn unbalanced_sinkhorn(
    D: f64,
    C: &Array2<f64>,
    a: &mut DiscreteMeasure,
    b: &mut DiscreteMeasure,
    ϵ: f64,
    tol: f64,
    max_iters: usize,
    warn: bool,
) -> (usize, f64) {
    // Restart from zero potentials so repeated calls are reproducible.
    a.dual_potential.iter_mut().for_each(|x| *x = 0.0);
    b.dual_potential.iter_mut().for_each(|x| *x = 0.0);

    let mut max_residual = f64::INFINITY;
    let mut iters = 0;

    // Disjoint-field borrows: potentials and caches are separate fields.
    let f = &mut a.dual_potential;
    let tmp_f = &mut a.cache;
    let g = &mut b.dual_potential;
    let tmp_g = &mut b.cache;

    // Bounds for the inner loops: `i` indexes C's *rows* in the g-update
    // and C's *columns* in the f-update.
    let min_length_a = a.log_density.len().min(tmp_f.len()).min(C.len_of(Axis(0)));
    // Fixed: this bounds a column index, so it must use Axis(1); the
    // original used Axis(0), which is wrong for non-square C.
    let min_length_b = b.log_density.len().min(tmp_g.len()).min(C.len_of(Axis(1)));

    while iters < max_iters && max_residual > tol {
        iters += 1;
        max_residual = 0.0;

        // Update g (potential on b) against the current f.
        for j in 0..g.len() {
            for i in 0..min_length_a {
                tmp_f[i] = a.log_density[i] + (f[i] - C[[i, j]]) / ϵ;
            }
            // Fixed: only the freshly written prefix participates; the
            // original summed the whole cache, including stale entries
            // whenever min_length_a < cache length.
            let new_g = -ϵ * logsumexp(&mut tmp_f[..min_length_a]);
            let new_g = -approx(D, ϵ, -new_g);
            let diff = (g[j] - new_g).abs();
            if diff > max_residual {
                max_residual = diff;
            }
            g[j] = new_g;
        }

        // Update f (potential on a) against the just-updated g.
        for j in 0..f.len() {
            for i in 0..min_length_b {
                tmp_g[i] = b.log_density[i] + (g[i] - C[[j, i]]) / ϵ;
            }
            let new_f = -ϵ * logsumexp(&mut tmp_g[..min_length_b]);
            let new_f = -approx(D, ϵ, -new_f);
            let diff = (f[j] - new_f).abs();
            if diff > max_residual {
                max_residual = diff;
            }
            f[j] = new_f;
        }
    }

    if warn && iters == max_iters {
        println!("Maximum iterations ({}) reached", max_iters);
    }

    (iters, max_residual)
}

/// Demo driver: builds two seeded random 5-point measures and a toy cost
/// matrix, runs unbalanced Sinkhorn, and prints timing plus the results.
fn main() {
    let n = 5; // problem size (points per measure)
    let scale = 10.0; // log-densities drawn from [0, scale)

    // Deterministic inputs via fixed seeds (1 for a, 2 for b).
    let mut a = rand_measure(n, scale, 1);
    let mut b = rand_measure(n, scale, 2);

    // Toy cost matrix for demonstration: C[i, j] = i + j.
    let C = Array::from_shape_fn((n, n), |(i, j)| (i + j) as f64);

    println!("Input a: {:?}", a);
    println!("Input b: {:?}", b);
    println!("Input C: {:?}", C);

    let start_time = Instant::now();

    // Run the unbalanced Sinkhorn algorithm.
    let (iters, max_residual) =
        unbalanced_sinkhorn(1.0, &C, &mut a, &mut b, 1e-1, 1e-5, 10_000, true);

    println!("Elapsed time: {:?}", start_time.elapsed());

    println!("Iterations: {}", iters);
    println!("Max residual: {}", max_residual);

    println!("a.dual_potential: {:?}", a.dual_potential);
    println!("b.dual_potential: {:?}", b.dual_potential);
}

[^1]: It does seem like a useful tool for code translation, but I think I would have been better off learning a bit more Rust syntax and concepts first, because I feel like I managed to get it running without gaining a good understanding of the language or the code.