abess-team / skscope

skscope: Sparse-Constrained OPtimization via itErative-solvers
https://skscope.readthedocs.io
MIT License
312 stars 13 forks source link

refactor: optimize loss of MultivariateFailure #75

Closed bbayukari closed 8 months ago

bbayukari commented 8 months ago

skmodel.MultivariateFailure

test code

import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import time

def multivariate_failure_objective(params, X, y, delta, n, K):
    """Loop-based negative partial log-likelihood for multivariate failure data.

    For every subject ``i`` and failure type ``k`` the contribution is

        Xbeta[i] - log(sum_j [y[j, k] >= y[i, k]] * exp(Xbeta[j])),

    where ``Xbeta = X @ params``.  The loss is minus the mean, over all n*K
    entries, of these contributions multiplied by ``delta``.  This is the
    straightforward reference implementation; it evaluates ``exp`` directly and
    therefore can overflow for large linear predictors.

    Args:
        params: Coefficient vector multiplied by ``X``.
        X: Design matrix.
        y: Observed times, indexed as ``y[subject, failure_type]``.
        delta: Weights (event indicators) broadcast against the (n, K) table.
        n: Number of subjects to iterate over.
        K: Number of failure types to iterate over.

    Returns:
        Scalar JAX array holding the loss.
    """
    linear_pred = jnp.matmul(X, params)
    # Loop-invariant: the risk scores are the same for every (i, k) pair.
    risk_scores = jnp.exp(linear_pred)

    table = []
    for subject in range(n):
        row = []
        for failure_type in range(K):
            # Subjects whose time for this failure type is >= the current subject's time.
            at_risk = y[:, failure_type] >= y[subject, failure_type]
            row.append(linear_pred[subject] - jnp.log(jnp.matmul(at_risk, risk_scores)))
        table.append(row)

    # The reshape keeps the (n, K) shape even when n or K is zero.
    contributions = jnp.array(table).reshape((n, K))
    return -jnp.mean(contributions * delta)

def multivariate_failure_objective_vectorized_logsumexp(params, X, y, delta, n, K):
    """Vectorized, overflow-safe negative partial log-likelihood for multivariate failure data.

    For every subject ``i`` and failure type ``k`` the contribution is

        Xbeta[i] - log(sum_j [y[j, k] >= y[i, k]] * exp(Xbeta[j])),

    with ``Xbeta = X @ params``.  The loss is minus the mean, over all n*K
    entries, of these contributions multiplied by ``delta``.  The risk-set sum
    is evaluated with ``logsumexp`` so large linear predictors do not overflow.

    Args:
        params: Coefficient vector multiplied by ``X``.
        X: Design matrix; ``X @ params`` must be one-dimensional.
        y: Observed times, a 2-D array indexed as ``y[subject, failure_type]``.
        delta: Weights (event indicators) broadcast against the (n, K) table.
        n: Unused; kept so the signature matches the loop-based objective.
        K: Unused; kept so the signature matches the loop-based objective.

    Returns:
        Scalar JAX array holding the loss.
    """
    Xbeta = jnp.matmul(X, params)
    # at_risk[i, j, k] is True when y[j, k] >= y[i, k], i.e. subject j is in the
    # risk set of (i, k).  The diagonal j == i is always True, so no risk set is empty.
    at_risk = y[None, :, :] >= y[:, None, :]
    # Apply the mask through logsumexp's weights ``b`` instead of adding
    # log(mask).  This gives the same value without building -inf entries.
    # Broadcasting explicitly gives ``a`` and ``b`` identical shapes.
    log_risk_sum = logsumexp(
        jnp.broadcast_to(Xbeta[None, :, None], at_risk.shape),
        axis=1,
        b=at_risk.astype(Xbeta.dtype),
    )  # shape (n, K): log of the at-risk sum of exp(Xbeta)
    loss = -jnp.mean((Xbeta[:, None] - log_risk_sum) * delta)
    return loss

def make_Clayton2_data(n, theta=15, lambda1=1, lambda2=1, c1=1, c2=1):
    """Simulate right-censored bivariate failure times coupled by a Clayton copula.

    Both marginals are exponential, with survival functions
    ``S_k(t) = exp(-lambda_k * t)``.  T2 is drawn by inverse transform from its
    marginal.  T1 is drawn from its conditional distribution given T2 under a
    Clayton copula with parameter ``theta`` on (S1(T1), S2(T2)).  Each time is
    then censored by an independent ``Uniform(0, c_k)`` censoring time.

    Uses NumPy's global RNG.  It draws u1, u2, then the two censoring vectors,
    each of length ``n``.

    Args:
        n: Number of subjects.
        theta: Clayton copula parameter (the formula assumes theta > 0).
        lambda1: Exponential rate for failure type 1; a scalar or a length-n array.
        lambda2: Exponential rate for failure type 2; a scalar or a length-n array.
        c1: Upper bound of the uniform censoring time for type 1.
        c2: Upper bound of the uniform censoring time for type 2.

    Returns:
        y: (n, 2) array of observed times ``min(T_k, C_k)``.
        delta: (n, 2) integer array; 1 where ``T_k < C_k`` (event observed), else 0.
    """
    u1 = np.random.uniform(0, 1, n)
    u2 = np.random.uniform(0, 1, n)
    # log S2(T2) = log(1 - u2), and log of the conditional-uniform draw w = 1 - u1.
    # log1p keeps precision when u is small.
    log_s2 = np.log1p(-u2)
    log_w = np.log1p(-u1)
    time2 = -log_s2 / lambda2
    # Inverting the conditional Clayton copula gives
    #   theta * lambda1 * T1 = log(1 - s2**-theta + w**(-theta/(1+theta)) * s2**-theta).
    # Evaluated directly, s2**-theta overflows for large theta (inf - inf -> NaN times).
    # Factoring s2**-theta out of the log gives the equivalent, overflow-free form
    #   -theta*log(s2) + log(s2**theta + w**(-theta/(1+theta)) - 1).
    inner = np.exp(theta * log_s2) + np.expm1(-theta / (1 + theta) * log_w)
    time1 = (-log_s2 + np.log(inner) / theta) / lambda1

    ctime1 = np.random.uniform(0, c1, n)
    ctime2 = np.random.uniform(0, c2, n)
    delta1 = (time1 < ctime1) * 1
    delta2 = (time2 < ctime2) * 1
    y = np.column_stack((np.minimum(time1, ctime1), np.minimum(time2, ctime2)))
    delta = np.column_stack((delta1, delta2))
    return y, delta

def test(seed):
    """Simulate one dataset and compare the loop-based and vectorized losses.

    The simulated data are:
      * n=100 subjects with p=100 zero-mean Gaussian covariates, with covariance
        ``rho**|i - j|`` and rho=0.5;
      * a coefficient vector whose first s=10 entries are 5;
      * two censored failure times per subject from ``make_Clayton2_data``.
    Both objectives are then evaluated at the same random parameter vector.

    Args:
        seed: Seed for NumPy's global RNG; it fixes the data and the parameters.

    Returns:
        ``(loss_original, time_original, loss_vectorized, time_vectorized)``.
        Each loss is a JAX scalar and each time is elapsed wall-clock seconds.

    NOTE(review): the name matches pytest's default ``test*`` pattern.  If this
    file is ever collected by pytest, it will look for a ``seed`` fixture.
    """
    np.random.seed(seed)
    n, p, s, rho = 100, 100, 10, 0.5
    K = 2  # make_Clayton2_data produces exactly two failure types

    beta = np.zeros(p)
    beta[:s] = 5
    # Sigma[i, j] = rho ** |i - j|
    Sigma = np.power(rho, np.abs(np.linspace(1, p, p) - np.linspace(1, p, p).reshape(p, 1)))
    X = np.random.multivariate_normal(mean=np.zeros(p), cov=Sigma, size=(n,))
    lambda1 = 1 * np.exp(np.matmul(X, beta))
    lambda2 = 10 * np.exp(np.matmul(X, beta))

    y, delta = make_Clayton2_data(n, theta=50, lambda1=lambda1, lambda2=lambda2, c1=5, c2=5)

    # Convert numpy arrays to jax numpy arrays
    X_jax = jnp.array(X)
    y_jax = jnp.array(y)
    delta_jax = jnp.array(delta)

    # Generate random parameters for testing
    params = jnp.array(np.random.randn(p))

    # JAX dispatches work asynchronously.  block_until_ready() makes each timer
    # cover the actual computation, not just its dispatch.
    t_start = time.perf_counter()
    loss_original = multivariate_failure_objective(params, X_jax, y_jax, delta_jax, n, K)
    loss_original.block_until_ready()
    t_mid = time.perf_counter()
    loss_vectorized = multivariate_failure_objective_vectorized_logsumexp(params, X_jax, y_jax, delta_jax, n, K)
    loss_vectorized.block_until_ready()
    t_end = time.perf_counter()

    return loss_original, t_mid - t_start, loss_vectorized, t_end - t_mid

if __name__ == "__main__":
    # Repeat the comparison for seeds 0..9 and report both losses and timings.
    for seed in range(10):
        orig_loss, orig_time, vec_loss, vec_time = test(seed)
        print("loss_original:   ", orig_loss, "time_original: ", orig_time)
        print("loss_vectorized: ", vec_loss, "time_vectorized: ", vec_time)
codecov[bot] commented 8 months ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Comparison is base (d87fe76) 94.24% compared to head (1190f61) 94.19%. Report is 1 commit behind head on master.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master      #75      +/-   ##
==========================================
- Coverage   94.24%   94.19%   -0.06%
==========================================
  Files          19       19
  Lines        2103     2101       -2
  Branches      653      653
==========================================
- Hits         1982     1979       -3
- Misses         91       92       +1
  Partials       30       30
```

:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Have feedback on the report? Share it here.