SurgicAI

MIT License

A Fine-grained Platform for Data Collection and Benchmarking in Surgical Policy Learning

Prerequisites

This section introduces the configuration required before training.

System Requirements

Ubuntu, ROS Noetic, Python, Torch, Gymnasium, Stable Baselines3, d3rlpy, and AMBF.

Installation
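
As a minimal sketch, assuming ROS Noetic and AMBF (with the surgical_robotics_challenge assets) are installed separately by following their own documentation, the Python-side dependencies can be installed with pip:

pip install gymnasium stable-baselines3 torch d3rlpy numpy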

RL Training

This section introduces the basic procedure for training a model in the defined Gymnasium environment.

Run the SRC Environment

Make sure ROS and the SRC (Surgical Robotics Challenge) simulator are running before moving on to the following steps. You can simply run the commands below or refer to this link for details.

roscore
~/ambf/bin/lin-x86_64/ambf_simulator --launch_file ~/ambf/surgical_robotics_challenge/launch.yaml -l 0,1,3,4,13,14 -p 200 -t 1 --override_max_comm_freq 120 --override_min_comm_freq 120

Register the Gymnasium Environment

import gymnasium as gym
from stable_baselines3.common.evaluation import evaluate_policy
from Approach_env import SRC_approach
import numpy as np
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env
from RL_algo.td3_BC import TD3_BC
from RL_algo.DemoHerReplayBuffer import DemoHerReplayBuffer
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.noise import NormalActionNoise
import time
# Register the custom SRC environment under an id, then instantiate it
gym.envs.register(id="TD3_HER_BC", entry_point=SRC_approach)
env = gym.make("TD3_HER_BC", render_mode="human", reward_type="sparse")
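
Since check_env is imported above, you can optionally verify that the registered environment conforms to the Gymnasium API before training. This is a standard Stable-Baselines3 sanity check, not a required step:

# Raises descriptive warnings/errors if the observation or action spaces
# are inconsistent with what Stable-Baselines3 expects
check_env(env, warn=True)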

Initialize and Train the Model

Here is an example of a model with TD3+HER+BC; see Approach_training_HER.ipynb for more details.

# `episode_transitions` (expert demonstration transitions) and
# `goal_selection_strategy` (an HER strategy such as "future") must be
# prepared beforehand; see Approach_training_HER.ipynb for how they are built.
model = TD3_BC(
    "MultiInputPolicy",
    env,
    replay_buffer_class=DemoHerReplayBuffer,
    policy_kwargs=dict(net_arch=dict(pi=[256, 256, 256], qf=[256, 256, 256])),
    replay_buffer_kwargs=dict(
        demo_transitions=episode_transitions,
        goal_selection_strategy=goal_selection_strategy,
    ),
    episode_transitions=episode_transitions,
)
checkpoint_callback = CheckpointCallback(save_freq=10000, save_path='./First_version/Model_temp', name_prefix='SRC')
model.learn(total_timesteps=1000000, progress_bar=True, callback=checkpoint_callback)
model.save("SRC")
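
NormalActionNoise is imported above but not used in this snippet. As an illustrative sketch (the noise scale here is an assumption, not the authors' setting), exploration noise can be set up with the standard Stable-Baselines3 helper and passed to the constructor:

# Hypothetical exploration-noise setup; sigma=0.1 is an illustrative value
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))
# then pass action_noise=action_noise when constructing TD3_BC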

Load the Model

model_path = "./Model/SRC_10000_steps.zip"
model = TD3_BC.load(model_path,env=env)
model.set_env(env=env)
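
With the model loaded, evaluate_policy (imported above) gives a quick quantitative check; n_eval_episodes=10 is an illustrative choice:

# Average episodic reward over a handful of evaluation episodes
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")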

Test the Model Prediction

obs,info = env.reset()
print(obs)
for i in range(10000):
    action, _state = model.predict(obs, deterministic=True)
    print(action)
    obs, reward, terminated, truncated, info = env.step(action)
    print(info)
    env.render()
    if terminated or truncated:
        obs, info = env.reset()

Low-level Training

The commands above walk through the basic pipeline. To train a model for a specific low-level policy, you can run the command below directly:

python3 RL_training_online.py --algorithm "$algorithm" --task_name "$task" --reward_type "$REWARD_TYPE" --total_timesteps "$TOTAL_TIMESTEPS" --save_freq "$SAVE_FREQ" --seed "$SEED" --trans_error "$TRANS_ERROR" --angle_error "$ANGLE_ERROR"
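
For example, a concrete invocation might look like the following; the argument values are illustrative, and the valid algorithm and task names are those defined in RL_training_online.py:

python3 RL_training_online.py --algorithm TD3 --task_name approach --reward_type sparse --total_timesteps 1000000 --save_freq 10000 --seed 1 --trans_error 0.003 --angle_error 10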

Model Evaluation

The following command evaluates the success rate, trajectory length, and time steps across five policies trained with different random seeds.

python3 Model_evaluation.py --algorithm "$algorithm" --task_name "$task" --reward_type "$REWARD_TYPE" --trans_error "$TRANS_ERROR" --angle_error "$ANGLE_ERROR" --eval_seed "$EVAL_SEED"
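
As with training, here is an illustrative invocation (the values are examples only):

python3 Model_evaluation.py --algorithm TD3 --task_name approach --reward_type sparse --trans_error 0.003 --angle_error 10 --eval_seed 0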

High-level Training

See High_level_HLP.ipynb for more details.
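
At a high level, the trained sub-task policies are executed in sequence to complete the full procedure. Below is a minimal, hypothetical sketch of that idea; the sub-task names, checkpoint paths, and switching logic are assumptions for illustration, and the actual implementation lives in High_level_HLP.ipynb:

# Hypothetical sequencing of low-level policies (illustrative names and paths)
subtasks = ["approach", "place", "pull", "handover"]
models = {name: TD3_BC.load(f"./Model/{name}.zip", env=env) for name in subtasks}

obs, info = env.reset()
for name in subtasks:
    terminated = truncated = False
    while not (terminated or truncated):
        action, _ = models[name].predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)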

The following video demonstrates the complete suturing procedure performed by our trained policy.

[demo video]

Here is some progress demonstrating our pipeline's transition to the latest SRC, focusing on the low-level 'Place' task.

[New_SRC_demo video]

If you find our work useful, please cite it as:

@misc{wu2024surgicaifinegrainedplatformdata,
      title={SurgicAI: A Fine-grained Platform for Data Collection and Benchmarking in Surgical Policy Learning}, 
      author={Jin Wu and Haoying Zhou and Peter Kazanzides and Adnan Munawar and Anqi Liu},
      year={2024},
      eprint={2406.13865},
      archivePrefix={arXiv},
      primaryClass={cs.RO},
      url={https://arxiv.org/abs/2406.13865}, 
}