jc-bao / policy-adaptation-survey

This repository is for comparing the prevailing adaptive control methods in both the control and learning communities.
Apache License 2.0

🪶 vector solution for `sympy` #37

Open jc-bao opened 1 year ago

jc-bao commented 1 year ago

💬

Hi. I need some help writing code in SymPy. I have the following 3 equations, each of which is a 3D vector equation, so there are actually 9 scalar equations.

    # Newton-Euler equations (9 scalar equations)
    # quadrotor
    eq_quad_pos = sp.Matrix([0, 0, -m*g]) + f_rope + thrust_world - m * acc
    eq_quad_rot = torque + \
        sp.Matrix.cross(hook_offset_world, f_rope) - \
        sp.Matrix.cross(omega, I @ omega) - I @ alpha
    # object
    eq_obj_pos = -f_rope + sp.Matrix([0, 0, -mo*g]) - mo * acc_obj
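For reference, here is one reading of the code above as math, with each residual set to zero. The shorthand is an assumption for readability and is not from the original code: $f$ = `f_rope`, $T$ = `thrust_world`, $\tau$ = `torque`, $r_h$ = `hook_offset_world`, $a$ = `acc`, $a_o$ = `acc_obj`, $m_o$ = `mo`, and $I$ is the inertia matrix.

```math
\begin{aligned}
0 &= \begin{bmatrix} 0 \\ 0 \\ -m g \end{bmatrix} + f + T - m\,a \\
0 &= \tau + r_h \times f - \omega \times (I\,\omega) - I\,\alpha \\
0 &= -f + \begin{bmatrix} 0 \\ 0 \\ -m_o g \end{bmatrix} - m_o\,a_o
\end{aligned}
```

Each line has three components, which gives the 9 scalar equations.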

The equations contain lots of variables, but I only want to solve for these 5 variables, which amount to 9 independent real numbers:

    acc = sp.MatrixSymbol("acc", 3, 1)
    alpha = sp.MatrixSymbol("alpha", 3, 1)
    theta_rope_ddot = sp.Symbol("theta_rope_ddot")
    phi_rope_ddot = sp.Symbol("phi_rope_ddot")
    f_rope_norm = sp.Symbol("f_rope_norm")
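A minimal sketch of how these 5 variables flatten into 9 scalar unknowns, using only the definitions above. `sp.Matrix(acc)` turns each `MatrixSymbol` into explicit entries `acc[0, 0]`, `acc[1, 0]`, `acc[2, 0]`, which are ordinary expressions that `.coeff`, `.subs` and `.jacobian` can work with. Note that the linear-solve approach further assumes (the snippet does not show how `f_rope` and `acc_obj` are built) that every equation is linear in these unknowns, e.g. that `f_rope` is `f_rope_norm` times a direction depending only on the rope angles.

```python
import sympy as sp

# Same unknowns as above: two 3x1 matrix symbols and three scalar symbols
acc = sp.MatrixSymbol("acc", 3, 1)
alpha = sp.MatrixSymbol("alpha", 3, 1)
theta_rope_ddot = sp.Symbol("theta_rope_ddot")
phi_rope_ddot = sp.Symbol("phi_rope_ddot")
f_rope_norm = sp.Symbol("f_rope_norm")

# Flatten everything into one list of scalar unknowns.
# sp.Matrix(acc) makes the entries explicit: acc[0, 0], acc[1, 0], acc[2, 0]
unknowns = (list(sp.Matrix(acc)) + list(sp.Matrix(alpha))
            + [theta_rope_ddot, phi_rope_ddot, f_rope_norm])

print(len(unknowns))  # 9
```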

In my old implementation, all variables were scalars, so I only needed to collect the equations and convert them into matrix form, A x = b. To get the values, I just solved the linear system for the wanted variables:

    eqs = [eq_quad_y, eq_quad_z, eq_quad_theta, eq_obj_y, eq_obj_z]
    # Solve for the acceleration
    A_taut_dyn = sp.zeros(5, 5)
    b_taut_dyn = sp.zeros(5, 1)
    for i in range(5):
        for j in range(5):
            A_taut_dyn[i, j] = eqs[i].coeff(states_dot_val[j])
        b_taut_dyn[i] = -eqs[i].subs([(states_dot_val[j], 0)
                                     for j in range(5)])
    # lambdify A_taut_dyn and b_taut_dyn
    A_taut_dyn_func = sp.lambdify(
        params + states_val + action, A_taut_dyn, "jax")
    b_taut_dyn_func = sp.lambdify(
        params + states_val + action, b_taut_dyn, "jax")

    # get variable values
    A = A_taut_dyn_func(*params, *states, *action)
    b = b_taut_dyn_func(*params, *states, *action)
    states_dot = jnp.linalg.solve(A, b).squeeze()
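To make the old approach concrete, here is a small self-contained sketch on a hypothetical 2-unknown toy system (`x1`, `x2`, `m`, `k` are made up for illustration). It uses `"numpy"` instead of `"jax"` in `lambdify` so it runs without JAX. It also shows one pitfall: `.coeff()` matches terms literally, so the equations must be expanded first.

```python
import numpy as np
import sympy as sp

# Hypothetical toy system: unknowns x1, x2; parameters m, k
x1, x2, m, k = sp.symbols("x1 x2 m k")
eqs = [m * (x1 - 2) + x2,   # deliberately left unexpanded
       k * x1 - x2 - 3]
unknowns = [x1, x2]

# Before expanding, x1 is hidden inside the product m*(x1 - 2), so .coeff(x1) would give 0
eqs = [sp.expand(e) for e in eqs]

A = sp.zeros(2, 2)
b = sp.zeros(2, 1)
for i in range(2):
    for j in range(2):
        A[i, j] = eqs[i].coeff(unknowns[j])
    # Setting all unknowns to zero leaves the constant part; A x = b needs its negative
    b[i] = -eqs[i].subs([(u, 0) for u in unknowns])

# Lambdify once, then evaluate and solve numerically
A_func = sp.lambdify([m, k], A, "numpy")
b_func = sp.lambdify([m, k], b, "numpy")
x = np.linalg.solve(A_func(2.0, 1.0), b_func(2.0, 1.0)).squeeze()
print(x)
```

With `m = 2` and `k = 1` the toy system is `2*x1 + x2 = 4` and `x1 - x2 = 3`, so the solve returns `x1 = 7/3` and `x2 = -2/3`.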

Now, how can I do the same thing for my new implementation?

jc-bao commented 1 year ago

🤖

    import sympy as sp
    import jax.numpy as jnp

    # ... (Your existing code defining eq_quad_pos, eq_quad_rot, eq_obj_pos, ...)
    # ... (Your existing code defining acc, alpha, theta_rope_ddot, phi_rope_ddot, f_rope_norm, ...)

    # 1. Convert the 3 vector equations into a list of 9 scalar equations.
    #    sp.Matrix(...) makes each equation explicit (this also works for
    #    expressions built from MatrixSymbols), and sp.expand is needed because
    #    .coeff() only finds the unknowns in fully expanded sums.
    eqs = []
    for eq in [eq_quad_pos, eq_quad_rot, eq_obj_pos]:
        eqs += [sp.expand(e) for e in sp.Matrix(eq)]

    # 2. Create the coefficient matrix A and the constant vector b (A x = b)
    unknowns = [acc[i] for i in range(3)] + [alpha[i] for i in range(3)] \
        + [theta_rope_ddot, phi_rope_ddot, f_rope_norm]
    num_eqs = len(eqs)            # 9
    num_unknowns = len(unknowns)  # 9

    A_new = sp.zeros(num_eqs, num_unknowns)
    b_new = sp.zeros(num_eqs, 1)

    for i in range(num_eqs):
        for j in range(num_unknowns):
            A_new[i, j] = eqs[i].coeff(unknowns[j])
        b_new[i] = -eqs[i].subs([(unknowns[j], 0) for j in range(num_unknowns)])

    # 3. Lambdify the coefficient matrix and the constant vector
    A_new_func = sp.lambdify(params + states_val + action, A_new, "jax")
    b_new_func = sp.lambdify(params + states_val + action, b_new, "jax")

    # 4. Solve the linear system using jnp.linalg.solve (or any other linear solver).
    #    As in your old code, params/states/action here hold the numeric values.
    A_new_val = A_new_func(*params, *states, *action)
    b_new_val = b_new_func(*params, *states, *action)
    unknowns_values = jnp.linalg.solve(A_new_val, b_new_val).squeeze()

    # Unpack in the same order as `unknowns`
    acc_val = unknowns_values[0:3]
    alpha_val = unknowns_values[3:6]
    theta_rope_ddot_val, phi_rope_ddot_val, f_rope_norm_val = unknowns_values[6:9]
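An alternative sketch for the vector case, not taken from the answer above: stack the residual vectors with `sp.Matrix.vstack`, take `A` as the Jacobian with respect to the unknowns (which does not require expanding first), and take `b` as minus the residual evaluated at zero. The result is cross-checked against `sp.linear_eq_to_matrix`, SymPy's built-in helper that also returns `A, b` with `A x = b`. The helper name `linear_system`, the point-mass model, its symbol names and the constraint `n . a = 0` are all hypothetical, only there to give a runnable 4x4 system; `"numpy"` is used so it runs without JAX.

```python
import numpy as np
import sympy as sp


def linear_system(vector_eqs, unknowns):
    """Hypothetical helper: stack residuals (each meant to equal zero) into A x = b.

    Assumes every residual is linear in `unknowns`. The Jacobian gives A
    directly, so the expressions do not need to be expanded first.
    """
    residual = sp.Matrix.vstack(*[sp.Matrix(e) for e in vector_eqs])
    A = residual.jacobian(unknowns)
    if A.has(*unknowns):
        raise ValueError("residuals are not linear in the unknowns")
    # With all unknowns set to zero only the constant part remains; b is its negative
    b = -residual.subs({u: 0 for u in unknowns})
    return A, b


# Hypothetical toy model (not the quadrotor): a point mass m pulled by a thrust T
# and a rope tension f along a fixed unit direction n, plus a made-up constraint n . a = 0
m, g, f = sp.symbols("m g f")
T_syms = sp.symbols("Tx Ty Tz")
n_syms = sp.symbols("nx ny nz")
a_syms = sp.symbols("ax ay az")
T, n, a = sp.Matrix(T_syms), sp.Matrix(n_syms), sp.Matrix(a_syms)

eq_pos = sp.Matrix([0, 0, -m * g]) + T + f * n - m * a   # 3 scalar equations
eq_rope = sp.Matrix([n.dot(a)])                          # 1 scalar equation
unknowns = [*a_syms, f]

A, b = linear_system([eq_pos, eq_rope], unknowns)

# Cross-check with SymPy's built-in helper
A2, b2 = sp.linear_eq_to_matrix([*eq_pos, *eq_rope], unknowns)
assert all(sp.expand(e) == 0 for e in (A2 - A))
assert all(sp.expand(e) == 0 for e in (b2 - b))

# Lambdify once, then evaluate and solve numerically (swap "numpy" for "jax" to match your code)
params = [m, g, *T_syms, *n_syms]
A_func = sp.lambdify(params, A, "numpy")
b_func = sp.lambdify(params, b, "numpy")
values = [2.0, 9.81, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]
x = np.linalg.solve(A_func(*values), b_func(*values)).squeeze()
print(x)  # ax, ay, az, f
```

For the original model, the call would presumably be `linear_system([eq_quad_pos, eq_quad_rot, eq_obj_pos], unknowns)` with the 9 unknowns listed in the answer; the `A.has(*unknowns)` check raises an error if some equation turns out not to be linear in them.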