Open jc-bao opened 1 year ago
import sympy as sp
import jax.numpy as jnp
# ... (Your existing code defining eq_quad_pos, eq_quad_rot, eq_obj_pos, ...)
# ... (Your existing code defining acc, alpha, theta_rope_ddot, phi_rope_ddot, f_rope_norm, ...)
# 1. Convert the matrix equations into a list of scalar equations
eqs = []
for eq in [eq_quad_pos, eq_quad_rot, eq_obj_pos]:
    eqs += [eq[i] for i in range(3)]
# 2. Create the coefficient matrix A and the constant vector b
unknowns = [acc[i] for i in range(3)] + [alpha[i] for i in range(3)] + [theta_rope_ddot, phi_rope_ddot, f_rope_norm]
num_eqs = len(eqs)
num_unknowns = len(unknowns)
A_new = sp.zeros(num_eqs, num_unknowns)
b_new = sp.zeros(num_eqs, 1)
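for i in range(num_eqs):
    # coeff() only matches terms of an expanded expression, so expand first
    eq_expanded = sp.expand(eqs[i])
    for j in range(num_unknowns):
        A_new[i, j] = eq_expanded.coeff(unknowns[j])
    # With A x = b and eq(x) = A x - b, setting every unknown to zero gives -b
    b_new[i] = -eqs[i].subs([(unknowns[j], 0) for j in range(num_unknowns)])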
# 3. Lambdify the coefficient matrix and the constant vector
A_new_func = sp.lambdify(params + states_val + action, A_new, "jax")
b_new_func = sp.lambdify(params + states_val + action, b_new, "jax")
# 4. Solve the linear system using jnp.linalg.solve (or any other linear solver)
A_new_val = A_new_func(*params, *states, *action)
b_new_val = b_new_func(*params, *states, *action)
unknowns_values = jnp.linalg.solve(A_new_val, b_new_val).squeeze()
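As a rough, self-contained illustration of steps 1 to 4, here is a toy version that can be run on its own. The equations, the symbols `m`, `g`, `f`, and the helper `solve_unknowns` are made up for this sketch and are not the quadrotor model above. It uses `sp.linear_eq_to_matrix` instead of the manual coefficient loop, and NumPy instead of JAX only to keep the example dependency-light; the structure should carry over when the `"jax"` backend is used.

```python
import numpy as np
import sympy as sp

# Hypothetical unknowns: a 3D vector `acc` and a scalar `f`.
ax, ay, az, f = sp.symbols("ax ay az f")
acc = sp.Matrix([ax, ay, az])

# Hypothetical known quantities.
m, g = sp.symbols("m g")

# Toy equations, each written as "expression = 0".
eq_vec = m * acc - f * sp.Matrix([0, 0, 1]) + sp.Matrix([0, 0, m * g])  # 3 rows
eq_scalar = ax + ay + az - f                                            # 1 row

# 1. Flatten vector equations and vector unknowns into scalar lists.
eqs = list(eq_vec) + [eq_scalar]
unknowns = list(acc) + [f]

# 2. Extract A and b such that A * unknowns = b.
#    linear_eq_to_matrix raises an error if an equation is not linear.
A, b = sp.linear_eq_to_matrix(eqs, unknowns)

# 3. Turn the symbolic matrices into numeric functions of the known values.
A_func = sp.lambdify([m, g], A, "numpy")
b_func = sp.lambdify([m, g], b, "numpy")


def solve_unknowns(m_val, g_val):
    """Evaluate A and b numerically and solve the linear system."""
    A_val = np.asarray(A_func(m_val, g_val), dtype=float)
    b_val = np.asarray(b_func(m_val, g_val), dtype=float)
    # 4. Solve; the result is ordered like `unknowns`: [ax, ay, az, f].
    return np.linalg.solve(A_val, b_val).squeeze()
```

Calling `solve_unknowns(2.0, 10.0)` returns the four scalar values in the same order as `unknowns`, so the first three entries can be regrouped into the vector `acc` afterwards.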
Hi. I need some help writing code in SymPy. I have the following 3 equations, each of which is a 3D vector equation, so there are actually 9 scalar equations.
The equations contain many variables, but I only want to solve for these 5 variables, which together have 9 independent real components.
In my old implementation, all variables were scalars, so I only needed to collect the equations and rewrite them in matrix form. To get the values, I just ran a linear solve for the wanted variables.
Now, please tell me how I can do the same thing in my new implementation.
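The matrix-form approach described here only works when every equation is linear in the wanted variables. As a minimal sketch of that assumption, the hypothetical helper `to_matrix_form` below builds `A` and `b` from scalar equations (each written as "expression = 0") and checks that `A * unknowns - b` reproduces the original expressions. The symbols `x`, `y`, `k` are placeholders, not variables from the actual model.

```python
import sympy as sp

x, y, k = sp.symbols("x y k")
unknowns = [x, y]


def to_matrix_form(eqs, unknowns):
    """Return (A, b) with A * unknowns = b, or raise if an equation is nonlinear."""
    # coeff() only sees terms of an expanded expression.
    eqs = [sp.expand(e) for e in eqs]
    A = sp.Matrix([[e.coeff(u) for u in unknowns] for e in eqs])
    # Setting all unknowns to zero leaves the constant part, which equals -b.
    b = sp.Matrix([-e.subs({u: 0 for u in unknowns}) for e in eqs])
    # For a linear system, A * unknowns - b must equal the original expressions.
    residual = A * sp.Matrix(unknowns) - b - sp.Matrix(eqs)
    if any(sp.simplify(r) != 0 for r in residual):
        raise ValueError("equations are not linear in the unknowns")
    return A, b


# Linear in x and y (k is treated as a known parameter).
A, b = to_matrix_form([2 * k * (x + y) - 3, x - y], unknowns)
```

A product such as `x * y` or a power such as `x**2` makes the residual nonzero, so the helper raises instead of silently returning a wrong matrix.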