yandexdataschool / Practical_RL

A course in reinforcement learning in the wild

Question Regarding Conjugate Gradient Algorithm #516

Closed AI-Ahmed closed 1 year ago

AI-Ahmed commented 1 year ago

I have been trying to understand more about the conjugate gradient algorithm. It is a really ingenious idea from Schulman, Prof. Pieter Abbeel, et al.

Thanks also go to you guys for contributing to this course and implementing the algorithm.

import torch


def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10):
    """
    Solves the system of equations Ax = b using an
    iterative method called conjugate gradients.

    :f_Ax: function that returns Ax for a given vector x
    :b: right-hand side of Ax = b
    :cg_iters: maximum number of iterations
    :residual_tol: tolerance on the squared residual for early stopping
    """
    p = b.clone()             # search direction p_0 (starts equal to the residual)
    r = b.clone()             # residual r_0 = b - A x_0, with x_0 = 0
    x = torch.zeros_like(b)   # solution vector, initialized to zero
    rdotr = torch.sum(r * r)  # squared residual norm, r_k^T r_k
    for i in range(cg_iters):
        z = f_Ax(p)                             # A p_k
        v = rdotr / (torch.sum(p * z) + 1e-8)   # step size alpha_k = r_k^T r_k / (p_k^T A p_k)
        x += v * p                              # x_{k+1} = x_k + alpha_k p_k
        r -= v * z                              # r_{k+1} = r_k - alpha_k A p_k
        newrdotr = torch.sum(r * r)
        mu = newrdotr / (rdotr + 1e-8)          # beta_k = r_{k+1}^T r_{k+1} / (r_k^T r_k)
        p = r + mu * p                          # p_{k+1} = r_{k+1} + beta_k p_k
        rdotr = newrdotr
        if rdotr < residual_tol:
            break
    return x
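
In case it helps anyone reading along, here is a minimal usage sketch (my own, not part of the course code) that exercises the function on a small symmetric positive-definite system. The names A, b_vec, and x_hat are made up for this example, and f_Ax is just an explicit matrix-vector product here, whereas in TRPO it would be a Fisher-vector product computed without ever forming A:

import torch

# Hypothetical 3x3 symmetric positive-definite matrix and right-hand side.
A = torch.tensor([[4.0, 1.0, 0.0],
                  [1.0, 3.0, 1.0],
                  [0.0, 1.0, 2.0]])
b_vec = torch.tensor([1.0, 2.0, 3.0])

# Explicit matrix-vector product standing in for the Fisher-vector product.
f_Ax = lambda x: A @ x

x_hat = conjugate_gradient(f_Ax, b_vec, cg_iters=10)
print(x_hat)              # approximate solution of A x = b
print(A @ x_hat - b_vec)  # residual; should be close to zero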

When I found the mathematical algorithm, there were a couple of things that confused me!

[image: pseudocode of the conjugate gradient algorithm]

In the algorithm:

  1. $r_{k}^{\top}r_k$ corresponds to rdotr, but I don't understand why we didn't transpose one of the $r$'s before multiplying it by itself.
  2. The same with $p_k^{\top} A p_k$, which corresponds to (torch.sum(p*z) + 1e-8).

Please direct me if I am missing something. Thanks!

AI-Ahmed commented 1 year ago

I found an answer that was really intuitive to me (sorry if my question was silly; linear algebra is essential for diving deep into this)!

[image: screenshot of the linked answer on the product of a vector and its transpose]

I thought that $r_k$ was a square matrix, but now I understand that $r_k^{\top}$ is a $1 \times n$ matrix (a row vector) while $r_k$ is an $n \times 1$ matrix (a column vector).

That means

$$r_k^{\top} r_k \;=\; \sum_{i=1}^{n} r_{k,i} \, r_{k,i} \;=\; \sum_{i=1}^{n} r_{k,i}^2,$$

which is exactly what torch.sum(r*r) computes. Ref: https://math.stackexchange.com/questions/1853808/product-of-a-vector-and-its-transpose-projections
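
A quick numerical check along these lines (again my own sketch, not part of the course code) shows that the row-times-column product $r^{\top} r$ agrees with the elementwise product summed, which is why the code never needs an explicit transpose; the same holds for $p^{\top} A p$ versus torch.sum(p*z):

import torch

torch.manual_seed(0)
r = torch.randn(5)               # residual vector, shape (n,)
p = torch.randn(5)               # search direction, shape (n,)
M = torch.randn(5, 5)
A = M @ M.T + 5 * torch.eye(5)   # symmetric positive-definite matrix
z = A @ p                        # corresponds to f_Ax(p)

# r^T r as a (1, n) row times an (n, 1) column...
rTr_matrix = (r.unsqueeze(0) @ r.unsqueeze(1)).squeeze()
# ...equals the elementwise product summed, which is what the code uses.
rTr_sum = torch.sum(r * r)
print(torch.allclose(rTr_matrix, rTr_sum))   # True

# Same story for p^T A p versus torch.sum(p * z).
pAp_matrix = (p.unsqueeze(0) @ z.unsqueeze(1)).squeeze()
pAp_sum = torch.sum(p * z)
print(torch.allclose(pAp_matrix, pAp_sum))   # True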

If there is anything else, please let me know!