Status: Open · Joshuaalbert opened this issue 6 months ago.
import numpy as np


def create_spd_matrix(n, seed=None):
    """Return a random ``n x n`` symmetric positive-definite NumPy matrix.

    ``A @ A.T`` is symmetric positive *semi*-definite; adding ``n * I`` raises
    every eigenvalue by ``n``, which makes the result strictly positive-definite.

    Args:
        n: Matrix dimension.
        seed: Optional seed for ``numpy.random.default_rng``. Pass a value to
            make the demo reproducible; ``None`` draws fresh entropy.

    Returns:
        A float64 array of shape ``(n, n)``.
    """
    rng = np.random.default_rng(seed)
    A = rng.random((n, n))  # uniform on [0, 1), same distribution as np.random.rand
    return A @ A.T + n * np.eye(n)


def create_spd_matrix_jax(n, key):
    """JAX counterpart of ``create_spd_matrix``.

    Entries are standard-normal draws from ``jax.random.normal(key, ...)``, so
    the result is fully determined by ``key``.
    """
    import jax.numpy as jnp
    from jax import random

    A = random.normal(key, (n, n))
    return A @ A.T + n * jnp.eye(n)


def check_lower_cholesky(A, cholesky_fn):
    """Factor only the lower triangle of ``A`` and check the reconstruction.

    ``cholesky_fn`` is called as ``cholesky_fn(tril(A), lower=True)`` (e.g.
    ``scipy.linalg.cholesky`` or ``jax.scipy.linalg.cholesky``). If it only
    reads the lower triangle, ``tril(L @ L.T)`` reproduces ``tril(A)``.

    Args:
        A: Square symmetric positive-definite matrix (NumPy or JAX array).
        cholesky_fn: Cholesky routine accepting a ``lower`` keyword.

    Returns:
        ``(A_lower, reconstructed_lower, is_correct)``: the two lower-triangular
        parts as NumPy arrays and a plain ``bool`` verdict.
    """
    A_lower = np.tril(np.asarray(A))
    L = np.asarray(cholesky_fn(A_lower, lower=True))
    reconstructed_lower = np.tril(L @ L.T)
    # np.allclose's default atol=1e-8 is absolute and too tight for near-zero
    # off-diagonal entries in float32; scale the tolerance to the dtype's
    # precision and the matrix magnitude instead.
    eps = np.finfo(np.result_type(L.dtype, np.float32)).eps
    atol = 100 * A_lower.shape[0] * eps * np.abs(A_lower).max(initial=0.0)
    is_correct = bool(np.allclose(A_lower, reconstructed_lower, atol=atol))
    return A_lower, reconstructed_lower, is_correct


def _report(A_lower, reconstructed_lower, is_correct):
    """Print the lower triangles and the verdict in the demo's output format."""
    print(f"Original Lower Triangular Part:\n{A_lower}")
    print(f"Reconstructed Lower Triangular Part from Cholesky Decomposition:\n{reconstructed_lower}")
    print(f"Decomposition Correct: {is_correct}")


if __name__ == '__main__':
    n = 5  # Dimension of the matrix

    # SciPy: with lower=True only the lower half of the input is referenced.
    from scipy.linalg import cholesky as scipy_cholesky

    _report(*check_lower_cholesky(create_spd_matrix(n, seed=0), scipy_cholesky))

    # JAX: the same check; arrays are float32 unless jax_enable_x64 is set,
    # which is why the tolerance above is dtype-aware.
    from jax import random
    from jax.scipy.linalg import cholesky as jax_cholesky

    _report(*check_lower_cholesky(create_spd_matrix_jax(n, random.PRNGKey(0)), jax_cholesky))
Enable simulation of a full 10.3-minute observation.
Input:
How:
$(N_a N_t N_d)^2$ values when stored as a dense matrix.