ami-iit / adam

adam implements a collection of algorithms for calculating rigid-body dynamics in Jax, CasADi, PyTorch, and Numpy.
https://adam-docs.readthedocs.io/en/latest/
BSD 3-Clause "New" or "Revised" License
131 stars 20 forks source link

Update pytorch interface #65

Closed Giulero closed 9 months ago

Giulero commented 9 months ago

The aim of this PR is to update the PyTorch interface and enable its use in deep-learning contexts (i.e., to exploit its differentiability). It also standardizes all computations on float64.

Giulero commented 9 months ago

I ran a quick (10-minute) test to verify that the PyTorch interface can be used inside a neural network.

I created a small NN that takes a constant value as input and has a few intermediate layers, the last of which is a forward kinematics function. The network outputs 2 homogeneous transforms (l_sole and l_lower_leg). The objective is to minimize the loss computed as the difference (for a certain definition of difference; in the example, nn.MSELoss()) between the predicted transforms and the target ones.

This is the script:

# test if the forward kinematics is differentiable and can be used in a neural network

import logging
import icub_models
import torch.nn as nn
import torch
from adam import Representations
from adam.pytorch import KinDynComputations
import matplotlib.pyplot as plt
from rich.progress import track

# For reproducible runs, uncomment the seeding line below. Note that
# `torch.random.seed()` accepts no argument (it picks a non-deterministic
# seed), so `torch.manual_seed` is the call that actually fixes the seed.
# torch.manual_seed(42)
# torch.set_default_dtype(torch.float64)

model_path = str(icub_models.get_model_file("iCubGazeboV2_5"))

joints_name_list = [
    "torso_pitch",
    "torso_roll",
    "torso_yaw",
    "l_shoulder_pitch",
    "l_shoulder_roll",
    "l_shoulder_yaw",
    "l_elbow",
    "r_shoulder_pitch",
    "r_shoulder_roll",
    "r_shoulder_yaw",
    "r_elbow",
    "l_hip_pitch",
    "l_hip_roll",
    "l_hip_yaw",
    "l_knee",
    "l_ankle_pitch",
    "l_ankle_roll",
    "r_hip_pitch",
    "r_hip_roll",
    "r_hip_yaw",
    "r_knee",
    "r_ankle_pitch",
    "r_ankle_roll",
]

logging.basicConfig(level=logging.DEBUG)
logging.debug("Showing the robot tree.")

root_link = "root_link"
comp = KinDynComputations(model_path, joints_name_list, root_link)
comp.set_frame_velocity_representation(Representations.MIXED_REPRESENTATION)

n_dofs = len(joints_name_list)

# Random base pose: position and roll-pitch-yaw angles, each component drawn
# uniformly from [-2.5, 2.5).
xyz = (torch.rand(3) - 0.5) * 5
rpy = (torch.rand(3) - 0.5) * 5

from adam.pytorch.torch_like import SpatialMath

# Named `spatial_math` (not `math`) to avoid shadowing the stdlib module name.
spatial_math = SpatialMath()
H_b = spatial_math.H_from_Pos_RPY(xyz, rpy).array

# Ground-truth joint configuration. Its forward kinematics are the training
# targets (they get detached before training), so no gradient tracking is
# needed here.
random_joints = torch.randn(n_dofs)
output_lsole = comp.forward_kinematics("l_sole", H_b, random_joints)
output_llower_leg = comp.forward_kinematics("l_lower_leg", H_b, random_joints)

class fkNN(nn.Module):
    """Small MLP whose final stage is adam's forward kinematics.

    The MLP maps an ``n_dofs``-sized input to ``n_dofs`` joint values. These
    are passed, together with the module-level base transform ``H_b``, to
    ``KinDynComputations.forward_kinematics`` to obtain the ``l_sole`` and
    ``l_lower_leg`` transforms. Training against target transforms therefore
    back-propagates through the PyTorch interface of adam.
    """

    def __init__(self):
        # Zero-argument super(): the script later rebinds the name ``fkNN`` to
        # an instance, after which ``super(fkNN, self)`` would raise TypeError.
        super().__init__()
        self.comp = KinDynComputations(model_path, joints_name_list, root_link)
        self.comp.set_frame_velocity_representation(
            Representations.MIXED_REPRESENTATION
        )
        self.fc1 = nn.Linear(n_dofs, 6 + n_dofs)
        self.act1 = nn.ReLU()
        self.fc2 = nn.Linear(6 + n_dofs, 6 + n_dofs)
        self.act2 = nn.ReLU()
        self.out = nn.Linear(6 + n_dofs, n_dofs)

    def get_joints(self, value):
        """Return the joint values predicted by the MLP layers for ``value``."""
        x = self.act1(self.fc1(value))
        x = self.act2(self.fc2(x))
        return self.out(x)

    def forward(self, joints_val):
        """Return the predicted ``(l_sole, l_lower_leg)`` transforms.

        Both transforms come from the same predicted joint vector, so the
        MLP layers are evaluated once (via ``get_joints``) per call.
        """
        x = self.get_joints(joints_val)
        return (
            self.comp.forward_kinematics("l_sole", H_b, x),
            self.comp.forward_kinematics("l_lower_leg", H_b, x),
        )

# Use a distinct name so the class `fkNN` is not shadowed by its instance.
model = fkNN()

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Constant network input. It is not optimized, so it needs no gradient
# (otherwise its .grad would accumulate every step and never be zeroed).
joints_val_torch = torch.ones(n_dofs)
# Detached copies of the ground-truth transforms, used as regression targets.
target_lsole = output_lsole.clone().detach()
target_llower_leg = output_llower_leg.clone().detach()
loss_array = []
for _ in track(range(30000), description="Training"):
    optimizer.zero_grad()
    output_lsole, output_llower_leg = model(joints_val_torch)
    # nn.MSELoss is called as criterion(input, target).
    loss = criterion(output_lsole, target_lsole) + criterion(
        output_llower_leg, target_llower_leg
    )
    loss.backward()
    optimizer.step()
    # Store a plain Python float rather than a tensor/array.
    loss_array.append(loss.item())
print(f"target lsole: {target_lsole}")
print(f"output lsole: {output_lsole}")
print(f"target llower leg: {target_llower_leg}")
print(f"output llower leg: {output_llower_leg}")
print(f"original joints: {random_joints}")
print(f"nn joints: {model.get_joints(joints_val_torch)}")

steps = range(len(loss_array))
plt.plot(steps, loss_array)
plt.xlabel("steps")
plt.ylabel("loss")
plt.show()

The outputs of the network, i.e., the predicted homogeneous transforms of l_sole and l_lower_leg, converge to the target values.

target lsole: tensor([[-0.4314,  0.7503, -0.5010,  1.1846],
        [ 0.7966,  0.5775,  0.1789,  1.1001],
        [ 0.4235, -0.3219, -0.8468, -0.6390],
        [ 0.0000,  0.0000,  0.0000,  1.0000]], dtype=torch.float64)
output lsole: tensor([[-0.4314,  0.7503, -0.5010,  1.1846],
        [ 0.7966,  0.5775,  0.1789,  1.1001],
        [ 0.4235, -0.3219, -0.8468, -0.6390],
        [ 0.0000,  0.0000,  0.0000,  1.0000]], dtype=torch.float64,
       grad_fn=<MmBackward0>)
target llower leg: tensor([[-0.4487,  0.8847,  0.1266,  1.1776],
        [ 0.7305,  0.2815,  0.6222,  1.2364],
        [ 0.5148,  0.3716, -0.7726, -0.8485],
        [ 0.0000,  0.0000,  0.0000,  1.0000]], dtype=torch.float64)
output llower leg: tensor([[-0.4487,  0.8847,  0.1266,  1.1776],
        [ 0.7305,  0.2815,  0.6222,  1.2364],
        [ 0.5148,  0.3716, -0.7726, -0.8485],
        [ 0.0000,  0.0000,  0.0000,  1.0000]], dtype=torch.float64,
       grad_fn=<MmBackward0>)

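The intermediate output of the network, i.e., the joint values fed to the forward kinematics function, converges to the joint values used to compute the targets, but only for the left-leg joints (the six values preceding the last six). The remaining joints do not move l_sole or l_lower_leg, so the loss does not constrain them and they do not converge to the original values.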

original joints: tensor([ 1.8165, -0.4723,  0.5580, -1.3924,  0.2683, -0.2554,  1.2683, -0.8576,
        -0.0050, -1.4101, -0.6394,  0.5347, -0.5572,  0.8706,  2.6310, -0.1141,
         0.7860,  0.1291, -0.4271,  0.0490, -1.9311, -0.9298,  0.2214],
       requires_grad=True)
nn joints: tensor([-0.0812, -0.1694,  0.3457, -0.3293, -0.7453,  0.0631,  0.4840, -0.2661,
        -0.3233,  0.0550,  0.3637,  0.5347, -0.5572,  0.8706,  2.6310, -0.1141,
         0.7860,  0.1906,  0.6204,  0.4075,  0.0221,  0.3031, -0.3142],

cc. @Zweisteine96

Giulero commented 9 months ago

Thanks @traversaro! Fixed

Giulero commented 9 months ago

Thanks! :)