UM-ARM-Lab / pytorch_kinematics

Robot kinematics implemented in PyTorch.
MIT License
394 stars 34 forks source link

Draft: forward kinematics w.r.t. differentiable kinematic parameters #32

Status: Open · JonathanKuelz opened this pull request 8 months ago

JonathanKuelz commented 8 months ago

Following #28, here's a first draft of a forward-kinematics implementation that is differentiable with respect to joint offsets and link offsets.

This PR implements the following (the item list was not preserved in this copy of the page):

It also includes some additional features I developed for debugging that are not strictly required (this list was not preserved in this copy of the page either):

JonathanKuelz commented 8 months ago

Here's a script I use for debugging that makes use of the new features:

#!/usr/bin/env python3
# Author: Jonathan Külz
from time import time

import torch
from torch.optim import Adam

from pytorch_kinematics import visualize, SerialChain
from pytorch_kinematics.transforms.parameterized_transform import MDHTransform

# Device used for every tensor in this script; CUDA becomes faster from a batch size of ~256 on.
if torch.cuda.is_available():
    d = "cuda"
else:
    d = "cpu"
# Floating-point precision shared by all tensors created below.
tensor_type = torch.float32

def _plot_loss_history(losses):
    """Plot a sequence of scalar losses per step on a logarithmic y-axis and show the figure."""
    # Local import: matplotlib is only needed for this optional debugging plot.
    import matplotlib.pyplot as plt

    plt.plot(losses)
    plt.yscale('log')
    plt.xlabel('Step')
    plt.ylabel('Loss')
    plt.show()


def main(n_goals=2, b=1, *, max_steps=10_000, tol=1e-3):
    """Optimize robot kinematic parameters and joint angles to reach random goals, one after the other.

    Starting from the Franka Emika Panda parameters, this script

    1. checks that ``robot.forward_kinematics`` agrees with a manual product of
       the per-joint matrices returned by ``MDHTransform.get_matrix``, and
    2. for each of ``n_goals`` random goal positions, optimizes the joint
       transforms with Adam until the batch-averaged goal-distance loss drops
       below ``tol`` or ``max_steps`` optimizer steps have been taken.

    Args:
        n_goals: Number of random goal positions to optimize for.
        b: Batch size, i.e. number of robots optimized in parallel.
        max_steps: Upper bound on optimizer steps per goal. It keeps the loop
            from running forever when the optimization does not converge.
        tol: Threshold on ``goal_distance_loss / b`` that ends the
            optimization for a goal.

    Returns:
        dict mapping each goal index to its list of ``(loss, transforms)``
        tuples, one per optimizer step. ``transforms`` is a CPU copy without
        autograd history when ``b == 1`` and ``None`` otherwise.

    Raises:
        ValueError: if ``max_steps`` is smaller than 1.
    """
    from time import perf_counter

    if max_steps < 1:
        raise ValueError(f"max_steps must be >= 1, got {max_steps}")

    goals = [20 * torch.randn((b, 3), dtype=tensor_type, device=d) for _ in range(n_goals)]

    # Initialize to Franka Emika panda
    # https://frankaemika.github.io/docs/control_parameters.html
    # Column order appears to be (alpha, a, d, theta). This matches the Franka
    # table and the write of q into column 3 below; confirm against MDHTransform.
    transforms = MDHTransform(parameters=torch.Tensor([
        [          0,       0,     0, 0],  # World joint
        [          0,       0, 0.333, 0],  # Joint 1
        [-torch.pi/2,       0,     0, 0],  # Joint 2
        [ torch.pi/2,       0, 0.316, 0],  # Joint 3
        [ torch.pi/2,  0.0825,     0, 0],  # Joint 4
        [-torch.pi/2, -0.0825, 0.384, 0],  # Joint 5
        [ torch.pi/2,       0,     0, 0],  # Joint 6
        [ torch.pi/2,   0.088,     0, 0],  # Joint 7
    ]), device=d, dtype=tensor_type, default_batch_size=(1,))
    robot = SerialChain.from_joint_transforms(transforms[1:, :])
    link_names = robot.get_link_names()  # queried once; reused by the check and the visualization

    q = torch.rand((b, 7), dtype=tensor_type, device=d)
    fk_robot = robot.forward_kinematics(q, end_only=False)

    if b > 1:
        transforms = transforms.stack(*[transforms] * (b - 1), dim=0)

    # Re-initializing transform, but with an additional "n-robot" batch dimension
    transforms = MDHTransform(parameters=transforms.parameters, device=d, dtype=tensor_type, default_batch_size=(1, 1))
    transforms.parameters.data[:, 1:, 3] = q
    M = transforms.get_matrix()

    # Sanity check: chaining the per-joint matrices by hand must reproduce the chain's FK.
    fk_manual = torch.eye(4, dtype=tensor_type, device=d).repeat(b, 1, 1)
    for i in range(7):
        fk_manual = fk_manual @ M[:, i, ...]
        link = link_names[i]
        assert torch.isclose(fk_manual, fk_robot[link].get_matrix(), atol=1e-3).all(), \
            f"Forward kinematics mismatch at link {link}"

    # Only the non-world joints are optimized; the world joint is re-attached on every step.
    X = MDHTransform(parameters=transforms.parameters.data[:, 1:, ...].detach()).to(transforms.device)
    optim = Adam([X.parameters], lr=1e-2)
    history = {}
    for i, goal in enumerate(goals):
        local_history = []
        d_goal = torch.linalg.vector_norm(goal, dim=1)
        converged = False
        t0 = perf_counter()
        for _ in range(max_steps):
            optim.zero_grad()
            T = transforms[:, 0, ...].stack(X, dim=1)
            fk = robot.forward_kinematics(joint_offsets=T, end_only=True)
            goal_distance_loss = ((fk.get_matrix()[:, :3, 3] - goal) ** 2).sum()
            # Penalize summed link lengths that exceed the distance to the goal.
            distance_overshoot = torch.relu(X.a.abs().sum(dim=1) + X.d.abs().sum(dim=1) - d_goal).sum()
            regularization = (X.alpha ** 2 + X.theta ** 2).sum()
            loss = (goal_distance_loss + distance_overshoot * 1e-2 + regularization * 1e-4) / b

            snapshot = None
            if b == 1:
                # Copy without autograd history so the stored list does not keep each step's graph alive.
                with torch.no_grad():
                    snapshot = T.clone().to('cpu')
            local_history.append((loss.item(), snapshot))

            loss.backward()
            optim.step()
            if goal_distance_loss.item() / b < tol:
                converged = True
                break

        print(f"Steps per second: {len(local_history) / (perf_counter() - t0):.2f}")
        if not converged:
            print(f"Goal {i}: no convergence within {max_steps} steps "
                  f"(goal loss {goal_distance_loss.item() / b:.3e})")
        history[i] = local_history

        if b == 1:
            # The snapshot was stored on the CPU; move it back to device `d`, where the robot was built.
            final_offsets = local_history[-1][1].to(d)
            fk = robot.forward_kinematics(joint_offsets=final_offsets, end_only=False)
            offsets = [fk[name] for name in link_names]
            visualize.visualize(offsets[0].stack(*offsets[1:]), show=True)

        # Plot loss history logarithmically
        _plot_loss_history([h[0] for h in local_history])

    return history

if __name__ == '__main__':
    main()