robodhruv / visualnav-transformer

Official code and checkpoint release for mobile robot foundation models: GNM, ViNT, and NoMaD.
MIT License
425 stars 56 forks source link

Deployment on Go1 #6

Closed threeeyelidds closed 5 months ago

threeeyelidds commented 8 months ago

Hi, great work!

Is it possible for you to share how to deploy this on a Unitree GO1?

ajaysridhar0 commented 5 months ago

Hi @threeeyelidds,

To deploy our model on the Unitree Go1, I provided a sample ROS node that takes in waypoints from our navigation model and outputs velocity commands for the Go1. This script works with the unitree_legged_sdk.


import sys
import time


import robot_interface as sdk

import numpy as np
import yaml
from typing import Tuple

import rospy
from geometry_msgs.msg import Twist
from std_msgs.msg import Float32MultiArray, Bool

from topic_names import WAYPOINT_TOPIC
from ros_data import ROSData

def clip_angle(angle):
    """Wrap ``angle`` (radians) into the half-open interval [-pi, pi)."""
    full_turn = 2 * np.pi
    # Shift by pi so the modulo lands in [0, 2*pi), then shift back.
    shifted = angle + np.pi
    return np.remainder(shifted, full_turn) - np.pi

# Robot parameters are read once, at import time, from a YAML file. The path
# is relative, so it is resolved against the process's current working
# directory rather than this file's location.
CONFIG_PATH = "../config/robot.yaml"
with open(CONFIG_PATH, "r") as f:
    robot_config = yaml.safe_load(f)
# Upper bound applied to the linear velocity in pd_controller.
MAX_V = robot_config["max_v"]
# Symmetric bound (+/-) applied to the angular velocity in pd_controller.
MAX_W = robot_config["max_w"]
# Topic name taken from the config; not referenced by the code visible here.
VEL_TOPIC = robot_config["vel_navi_topic"]
# Threshold below which a waypoint component is treated as zero.
EPS = 1e-8
# Passed to ROSData as the waypoint timeout.
WAYPOINT_TIMEOUT = 1 # seconds # TODO: tune this
# Time step used to turn waypoint offsets into velocities (1 / frame_rate).
DT = 1/robot_config["frame_rate"]

def pd_controller(waypoint: np.ndarray) -> Tuple[float, float]:
    """Convert a waypoint into clipped linear and angular velocity commands.

    Args:
        waypoint: Either ``(dx, dy)`` or ``(dx, dy, hx, hy)``. ``dx``/``dy``
            are the positional offset to the waypoint and ``hx``/``hy`` are
            the components of a heading vector (presumably all expressed in
            the robot frame -- confirm against the waypoint publisher).

    Returns:
        ``(v, w)`` where ``v`` is clipped to ``[0, MAX_V]`` and ``w`` is
        clipped to ``[-MAX_W, MAX_W]``.

    Raises:
        AssertionError: If ``waypoint`` does not have 2 or 4 elements.
    """

    assert len(waypoint) == 2 or len(waypoint) == 4, "waypoint must be a 2D or 4D vector"
    if len(waypoint) == 2:
        dx, dy = waypoint
    else:
        dx, dy, hx, hy = waypoint
    # this controller only uses the predicted heading if dx and dy near zero
    if len(waypoint) == 4 and np.abs(dx) < EPS and np.abs(dy) < EPS:
        # No translation requested: rotate in place towards the heading.
        v = 0
        w = clip_angle(np.arctan2(hy, hx)) / DT
    elif np.abs(dx) < EPS:
        # Purely lateral offset: turn a quarter revolution towards it and
        # avoid dividing by a near-zero dx below.
        v = 0
        w = np.sign(dy) * np.pi / (2 * DT)
    else:
        # General case: drive forward while steering towards the waypoint.
        v = dx / DT
        w = np.arctan(dy / dx) / DT
    # Negative linear velocities are clipped to zero (no reversing).
    v = np.clip(v, 0, MAX_V)
    w = np.clip(w, -MAX_W, MAX_W)
    return v, w

def callback_drive(waypoint_msg: Float32MultiArray):
    """Callback for the waypoint subscriber: keep the latest waypoint.

    Args:
        waypoint_msg: Incoming message whose ``data`` field holds the
            waypoint consumed by ``pd_controller`` in the main loop.
    """
    print("seting waypoint")
    # Without storing the payload, the main loop's waypoint.is_valid() check
    # can never succeed from subscriber data.
    # NOTE(review): assumes ROSData exposes set(data) alongside the get() and
    # is_valid() used by the main loop -- confirm against ros_data.py.
    waypoint.set(waypoint_msg.data)

# Twist message instance; not referenced by the code visible here.
vel_msg = Twist()
# Holder for the latest waypoint; the main loop polls it through is_valid()
# and get(). WAYPOINT_TIMEOUT is passed as its timeout.
waypoint = ROSData(WAYPOINT_TIMEOUT, name="waypoint")

RATE = 500 # Hz

if __name__ == '__main__':
    # Entry point: subscribe to waypoints, convert the latest valid one into
    # a high-level velocity command, and stream it to the robot over UDP.
    rospy.init_node('go1_vel_controller', anonymous=True)
    waypoint_sub = rospy.Subscriber(WAYPOINT_TOPIC, Float32MultiArray, callback_drive, queue_size=1)

    HIGHLEVEL = 0xee
    LOWLEVEL  = 0xff

    # NOTE(review): the target IP is an empty string here -- fill in the
    # robot's address for your network setup before deploying.
    udp = sdk.UDP(HIGHLEVEL, 8080, "", 8082)

    cmd = sdk.HighCmd()
    state = sdk.HighState()
    # Initialise the command packet header, as the unitree_legged_sdk
    # high-level examples do before the first send.
    udp.InitCmdData(cmd)

    # Pace the loop at RATE Hz instead of busy-spinning.
    rate = rospy.Rate(RATE)

    motiontime = 0
    while not rospy.is_shutdown():
        motiontime += 1

        # Start every cycle from a neutral command so a stale or missing
        # waypoint results in the robot standing still.
        cmd.mode = 0      # 0:idle, default stand      1:forced stand     2:walk continuously
        cmd.gaitType = 0
        cmd.speedLevel = 0
        cmd.footRaiseHeight = 0
        cmd.bodyHeight = 0
        cmd.euler = [0, 0, 0]
        cmd.velocity = [0, 0]
        cmd.yawSpeed = 0.0
        cmd.reserve = 0

        if waypoint.is_valid():
            v, w = pd_controller(waypoint.get())
            # Velocity targets are only acted on in walking mode; in idle
            # mode (0) the robot just stands.
            cmd.mode = 2
            # NOTE(review): gaitType is left at 0; the SDK walking examples
            # select a trot gait (1 or 2) -- confirm which gait is wanted.
            cmd.footRaiseHeight = 0.01
            # Plain floats keep numpy scalars out of the SDK bindings.
            cmd.velocity = [float(v), 0]
            cmd.yawSpeed = float(w)

        print("cmd", cmd.velocity, cmd.yawSpeed)

        # Actually transmit the command; without these two calls the robot
        # never receives anything.
        udp.SetSend(cmd)
        udp.Send()

        try:
            rate.sleep()
        except rospy.ROSInterruptException:
            # Raised when the node is shut down while sleeping.
            break


Please let me know if you have any issues.
