Unity-Technologies / ml-agents

The Unity Machine Learning Agents Toolkit (ML-Agents) is an open-source project that enables games and simulations to serve as environments for training intelligent agents using deep reinforcement learning and imitation learning.
https://unity.com/products/machine-learning-agents
Other
16.81k stars 4.11k forks

ML Agents simulation extremely choppy #6127

Closed: dhruvmsheth closed this issue 1 week ago

dhruvmsheth commented 1 month ago

I've created a simple MLAgents script that sends the state of the object's joints and receives velocity commands for the actuators, so that the joints move to the desired state. I'm not using the MLAgents training procedure here. I'm using MLAgents in Unity only to interface with the Python low-level API: receive the state input, process it with a PID controller, and send the joint actuations back to be executed.

When I implement the same process natively in Unity with a C# script and no MLAgents, the simulation is extremely smooth (my timestep is 0.02) and the motion is continuous. However, when I remove that script, attach the MLAgents C# script instead, and simply hit Play, the simulation becomes extremely choppy and runs the same process at about 1 FPS. The decision period of my Decision Requester is set to 1, because I'm not training a model and I want the PID controller to receive the state every frame, to imitate the MuJoCo simulation I had.

I'm new to MLAgents, so I might be making an obviously dumb mistake, but any help would be appreciated. Sorry for tagging you @Balint-H, but since you have worked with MuJoCo + MLAgents, I thought tagging you might help.
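The issue never shows the `PIDControl` class or its `GeneratePidAction` method. For readers unfamiliar with the pattern, here is a minimal discrete PID step for a single joint. It assumes a fixed timestep `dt` (0.02 s in this setup) and one scalar error per joint. The `SimplePid` name and its members are hypothetical. This is a generic sketch, not the author's controller.

```csharp
using System;

// Hypothetical minimal discrete PID controller for one joint.
// Not the PIDControl class used in this issue (that class is not shown).
public sealed class SimplePid
{
    private readonly double kp, ki, kd;
    private double integral;
    private double previousError;
    private bool hasPrevious;

    public SimplePid(double kp, double ki, double kd)
    {
        this.kp = kp;
        this.ki = ki;
        this.kd = kd;
    }

    // Returns the control output for one step of length dt (seconds).
    public double Step(double target, double measured, double dt)
    {
        if (dt <= 0) throw new ArgumentOutOfRangeException(nameof(dt));
        double error = target - measured;
        integral += error * dt; // rectangular integration
        // No derivative on the first call, to avoid a spike from a missing history.
        double derivative = hasPrevious ? (error - previousError) / dt : 0.0;
        previousError = error;
        hasPrevious = true;
        return kp * error + ki * integral + kd * derivative;
    }

    // Clears accumulated state, e.g. at the start of an episode.
    public void Reset()
    {
        integral = 0.0;
        previousError = 0.0;
        hasPrevious = false;
    }
}
```

In Unity, one instance per joint would typically be stepped once per `FixedUpdate` (or once per decision), and `Reset()` would be called from `OnEpisodeBegin`.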


Attaching the script and the MLAgents configuration here if that helps.

MLAgents script that causes the simulation to become choppy:

```csharp
using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Mujoco;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using System;
using System.Linq;
using System.Collections.Generic;

public class agentCallback : Agent
{
    private double[] initialState;
    private double[] currentState;
    private MjScene mjScene;
    private bool isInitialized = false;
    private PIDControl pidController;
    private int nSteps = 0;
    public double targetHeight = 0.1;
    public double targetRadius;
    public double targetAngle;

    public override void Initialize()
    {
        base.Initialize();
        StartCoroutine(InitializeMjScene());
    }

    private IEnumerator InitializeMjScene()
    {
        while (MjScene.Instance == null)
        {
            yield return null;
        }
        mjScene = MjScene.Instance;
        isInitialized = true;
    }

    unsafe public override void OnEpisodeBegin()
    {
        // base.OnEpisodeBegin();
        if (!isInitialized)
        {
            Debug.LogWarning("MjScene not initialized yet. Waiting...");
            return;
        }
        MujocoLib.mjModel_* Model = mjScene.Model;
        MujocoLib.mjData_* Data = mjScene.Data;
        MujocoLib.mj_resetData(Model, Data);
        int stateSize = Model->nq - 4 + 3 + Model->nv + Model->na;
        initialState = getStateInformation(mjScene, stateSize);
        currentState = new double[stateSize];
        Array.Copy(initialState, currentState, stateSize);
        pidController = new PIDControl(targetAngle, targetRadius, targetHeight);
    }

    unsafe public double[] getStateInformation(MjScene mjScene, int stateSize)
    {
        MujocoLib.mjModel_* Model = mjScene.Model;
        MujocoLib.mjData_* Data = mjScene.Data;
        double[] dq = new double[Model->nv];
        fixed (double* dqPtr = dq)
        {
            MujocoLib.mj_differentiatePos(Model, dqPtr, 1, Model->qpos0, Data->qpos);
        }
        stateSize = Model->nq - 4 + 3 + Model->nv + Model->na;
        double[] currState = new double[stateSize];
        int index = 0;
        List<MjBaseJoint> allJoints = new List<MjBaseJoint>();
        allJoints.AddRange(this.GetComponentsInChildren<MjHingeJoint>());
        allJoints.AddRange(this.GetComponentsInChildren<MjFreeJoint>());
        List<string> jointNames = new List<string>() { "tilt-joint", "slew-joint", "luff-joint", "free-joint", }; // order in mujocopy
        Dictionary<string, int> jointQposMap = new Dictionary<string, int>() { {"tilt-joint", 1}, {"slew-joint", 1}, {"luff-joint", 1}, {"free-joint", 7} };
        Dictionary<string, int> jointVelMap = new Dictionary<string, int>() { {"tilt-joint", 1}, {"slew-joint", 1}, {"luff-joint", 1}, {"free-joint", 6} };
        foreach (var joint in allJoints)
        {
            // Debug.Log($"Name of joint is {jointQposMap}");
            if (jointNames.Contains(joint.name))
            {
                int qposCount = jointQposMap.GetValueOrDefault(joint.name);
                int startIndex = Model->jnt_qposadr[joint.MujocoId];
                if (joint.name == "free-joint")
                {
                    for (int i = 0; i < qposCount - 4; i++)
                    {
                        currState[index++] = Data->qpos[startIndex + i];
                    }
                }
                else
                {
                    for (int i = 0; i < qposCount; i++)
                    {
                        currState[index++] = Data->qpos[startIndex + i];
                    }
                }
            }
        }
        // qpos rotation values not required for freeJoint
        // currState = currState.Take(currState.Length - 4).ToArray();
        // add dq only for quaternion changes in free joint rotations
        int qvelFreeJnt = jointVelMap.GetValueOrDefault("free-joint");
        // 3 dofs for translational velocity are not included
        for (int i = qvelFreeJnt - 3; i < qvelFreeJnt; i++)
        {
            currState[index++] = dq[i];
        }
        foreach (var joint in allJoints)
        {
            if (jointNames.Contains(joint.name))
            {
                int qvelCount = jointVelMap.GetValueOrDefault(joint.name);
                int startIndex = Model->jnt_dofadr[joint.MujocoId];
                for (int i = 0; i < qvelCount; i++)
                {
                    currState[index++] = Data->qvel[startIndex + i];
                }
            }
        }
        // Actuator activation list
        for (int i = 0; i < Model->na; i++)
        {
            currState[index++] = Data->act[i];
        }
        return currState;
    }

    // Start is called before the first frame update
    unsafe public override void CollectObservations(VectorSensor sensor)
    {
        // base.CollectObservations(sensor);
        MujocoLib.mjModel_* Model = mjScene.Model;
        MujocoLib.mjData_* Data = mjScene.Data;
        int stateSize = Model->nq - 4 + 3 + Model->nv + Model->na;
        currentState = getStateInformation(mjScene, stateSize);
        List<float> observationList = currentState.Select(d => (float)d).ToList();
        sensor.AddObservation(observationList); // size 22
    }

    unsafe public override void OnActionReceived(ActionBuffers actions)
    {
        // base.OnActionReceived(actions);
        var continuousActions = actions.ContinuousActions;
        List<float> actionList = new List<float>(continuousActions.Length);
        for (int ii = 0; ii < continuousActions.Length; ii++)
        {
            actionList.Add(continuousActions[ii]);
        }
        if (mjScene != null)
        {
            MujocoLib.mjModel_* Model = mjScene.Model;
            MujocoLib.mjData_* Data = mjScene.Data;
            for (int i = 0; i < continuousActions.Length; i++)
            {
                Data->ctrl[i] = actionList[i];
            }
            MujocoLib.mj_step(Model, Data);
        }
        // concluding
        SetReward(1f); // arbitrary unused reward
        // EndEpisode(); // we don't want to end the episode each time this is run
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        // actionsOut is a 4 input vector that takes the input from the PIDController
        // The way we do this is we take the current state as the input and then based on that
        // take the input from the PIDController to generate actions
    }
}
```
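The observation size in these scripts is `nq - 4 + 3 + nv + na`. That is every `qpos` entry except the free joint's 4 quaternion entries, plus the 3 rotational components of `dq` from `mj_differentiatePos`, plus all of `qvel`, plus all actuator activations. In MuJoCo a hinge joint occupies 1 entry in `qpos` and 1 in `qvel`, while a free joint occupies 7 and 6. The sketch below redoes that bookkeeping for the three hinge joints plus one free joint described in the thread. `StateLayout` is a hypothetical helper. `na` is left as a parameter because the model XML isn't shown, though the `// size 22` comment would be consistent with `na = 4` if these are the model's only joints.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical sketch of the state-vector bookkeeping in getStateInformation.
// Per-joint (qpos, qvel) widths follow MuJoCo's joint conventions.
public static class StateLayout
{
    public static (int nq, int nv) JointSizes(string jointType) => jointType switch
    {
        "hinge" => (1, 1),
        "slide" => (1, 1),
        "ball" => (4, 3), // unit quaternion; 3D angular velocity
        "free" => (7, 6), // position + unit quaternion; linear + angular velocity
        _ => throw new ArgumentException($"Unknown joint type: {jointType}"),
    };

    // (nq - 4): every qpos entry except the free joint's quaternion
    // + 3:      rotational part of dq from mj_differentiatePos
    // + nv:     every qvel entry
    // + na:     actuator activations
    // Assumes exactly one free joint, as in the scripts in this issue.
    public static int StateSize(IEnumerable<string> jointTypes, int na)
    {
        if (na < 0) throw new ArgumentOutOfRangeException(nameof(na));
        int nq = 0, nv = 0, freeJoints = 0;
        foreach (string jointType in jointTypes)
        {
            var (q, v) = JointSizes(jointType);
            nq += q;
            nv += v;
            if (jointType == "free") freeJoints++;
        }
        if (freeJoints != 1)
            throw new ArgumentException("The formula assumes exactly one free joint.");
        return (nq - 4) + 3 + nv + na;
    }
}
```

With three hinges and one free joint, `nq = 10` and `nv = 9`, so the size is `18 + na`. The scripts also copy `dq[3]` to `dq[5]` from the model-wide `dq` vector. A free joint's 6 velocity DoFs are ordered linear first, then angular. Those indices are therefore the free joint's rotational components only if its `jnt_dofadr` is 0, i.e. if it is the first joint in the model.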

MLAgents configuration details: [image not preserved]

Thank you for the help!

Balint-H commented 1 month ago

Hello! This question would have been a good fit for the MuJoCo repo as well, as it concerns the plugin. A couple issues at first look:

dhruvmsheth commented 1 month ago

Thanks a lot for the suggestions @Balint-H! I made the updates and tried the script again but the frames still seem to be choppy. The suggestions were useful though, I completely missed them.

Script with the changes made:

```csharp
using System.Collections;
using UnityEngine;
using Unity.MLAgents;
using Mujoco;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using System;
using System.Linq;
using System.Collections.Generic;

public class agentCallback : Agent
{
    private double[] initialState;
    private double[] currentState;
    private MjScene mjScene;
    private PIDControl pidController;
    private List<MjBaseJoint> allJoints;
    private Dictionary<string, int> jointQposMap;
    private Dictionary<string, int> jointVelMap;
    private List<string> jointNames;
    public double targetHeight = 0.1;
    public double targetRadius;
    public double targetAngle;
    private bool isInitialized;

    public override void Initialize()
    {
        base.Initialize();
        MjScene.Instance.postInitEvent += OnMjSceneInitialized;
    }

    private void OnMjSceneInitialized(object sender, EventArgs e)
    {
        mjScene = MjScene.Instance;
        InitializeJointsAndStates();
        isInitialized = true;
    }

    private void InitializeJointsAndStates()
    {
        allJoints = new List<MjBaseJoint>();
        allJoints.AddRange(this.GetComponentsInChildren<MjHingeJoint>());
        allJoints.AddRange(this.GetComponentsInChildren<MjFreeJoint>());
        jointNames = new List<string>() { "tilt-joint", "slew-joint", "luff-joint", "free-joint", };
        jointQposMap = new Dictionary<string, int>() { {"tilt-joint", 1}, {"slew-joint", 1}, {"luff-joint", 1}, {"free-joint", 7} };
        jointVelMap = new Dictionary<string, int>() { {"tilt-joint", 1}, {"slew-joint", 1}, {"luff-joint", 1}, {"free-joint", 6} };
    }

    unsafe public override void OnEpisodeBegin()
    {
        if (!isInitialized)
        {
            Debug.LogWarning("MjScene not initialized yet. Waiting...");
            return;
        }
        MujocoLib.mjModel_* Model = mjScene.Model;
        MujocoLib.mjData_* Data = mjScene.Data;
        MujocoLib.mj_resetData(Model, Data);
        int stateSize = Model->nq - 4 + 3 + Model->nv + Model->na;
        initialState = GetStateInformation(mjScene, stateSize);
        currentState = new double[stateSize];
        Array.Copy(initialState, currentState, stateSize);
        pidController = new PIDControl(targetAngle, targetRadius, targetHeight);
    }

    unsafe private double[] GetStateInformation(MjScene mjScene, int stateSize)
    {
        MujocoLib.mjModel_* Model = mjScene.Model;
        MujocoLib.mjData_* Data = mjScene.Data;
        double[] dq = new double[Model->nv];
        fixed (double* dqPtr = dq)
        {
            MujocoLib.mj_differentiatePos(Model, dqPtr, 1, Model->qpos0, Data->qpos);
        }
        double[] currState = new double[stateSize];
        int index = 0;
        foreach (var joint in allJoints)
        {
            if (jointNames.Contains(joint.name))
            {
                int qposCount = jointQposMap.GetValueOrDefault(joint.name);
                int startIndex = Model->jnt_qposadr[joint.MujocoId];
                if (joint.name == "free-joint")
                {
                    for (int i = 0; i < qposCount - 4; i++)
                    {
                        currState[index++] = Data->qpos[startIndex + i];
                    }
                }
                else
                {
                    for (int i = 0; i < qposCount; i++)
                    {
                        currState[index++] = Data->qpos[startIndex + i];
                    }
                }
            }
        }
        int qvelFreeJnt = jointVelMap.GetValueOrDefault("free-joint");
        for (int i = qvelFreeJnt - 3; i < qvelFreeJnt; i++)
        {
            currState[index++] = dq[i];
        }
        foreach (var joint in allJoints)
        {
            if (jointNames.Contains(joint.name))
            {
                int qvelCount = jointVelMap.GetValueOrDefault(joint.name);
                int startIndex = Model->jnt_dofadr[joint.MujocoId];
                for (int i = 0; i < qvelCount; i++)
                {
                    currState[index++] = Data->qvel[startIndex + i];
                }
            }
        }
        for (int i = 0; i < Model->na; i++)
        {
            currState[index++] = Data->act[i];
        }
        return currState;
    }

    unsafe public override void CollectObservations(VectorSensor sensor)
    {
        MujocoLib.mjModel_* Model = mjScene.Model;
        MujocoLib.mjData_* Data = mjScene.Data;
        int stateSize = Model->nq - 4 + 3 + Model->nv + Model->na;
        currentState = GetStateInformation(mjScene, stateSize);
        List<float> observationList = currentState.Select(d => (float)d).ToList();
        sensor.AddObservation(observationList);
    }

    unsafe public override void OnActionReceived(ActionBuffers actions)
    {
        var continuousActions = actions.ContinuousActions;
        if (mjScene != null)
        {
            MujocoLib.mjModel_* Model = mjScene.Model;
            MujocoLib.mjData_* Data = mjScene.Data;
            for (int i = 0; i < continuousActions.Length; i++)
            {
                Data->ctrl[i] = continuousActions[i];
            }
        }
    }

    public override void Heuristic(in ActionBuffers actionsOut) {}
}
```

Balint-H commented 1 month ago

As far as performance goes, I don't see anything in your script that I'd expect to cause a slowdown. The only other thing I can think of: how many DoFs/instances do you have in your scene? I do quite heavy computations each frame in my scene (heavily unoptimized), and I can only have up to 5-6 humanoids while maintaining good framerates. If you are trying dozens, the plugin might struggle with that.

You could use the Unity profiler to identify your bottleneck.

Other minor considerations:

Balint-H commented 1 month ago

Do you get any exceptions or error messages in the editor? Also, if you have termination conditions, check whether they are triggering. It's possible they are continuously resetting your scene.

dhruvmsheth commented 1 month ago

Thanks for the prompt reply @Balint-H! My model is lightweight: it only has 4 joints, one of which is a free joint, and there is only 1 instance of it. I'll definitely try the profiler and let you know if I find something. The only message I get in my terminal is "Couldn't connect to trainer on port 5004 using API version 1.5.0. Will perform inference instead.". I don't think I have any termination conditions; I terminate manually through the Python MLAgents API.

Interestingly, when I control the simulation locally using the MuJoCo plugin, it works perfectly fine, the physics is decent, and the FPS is high. The problem only appears when I attach the MLAgents script to the agent. Even before I build and control the agent through the Python API, just attaching the script and hitting Play makes the simulation choppy and drops it to 1 FPS. That makes me suspect an initialization issue in MLAgents: I'm probably missing something, or maybe there's some interference between the script that controls it locally and the MLAgents script. (I don't attach the local control script when the MLAgents script is attached.) Additionally, when I change the decision period to 20, the FPS is slightly higher but the choppiness remains; with a decision period of 1, the FPS is below 1. Thanks for all the help!
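Some context on the decision-period observation. By default the ML-Agents Academy steps once per `FixedUpdate`, and a `DecisionRequester` with `DecisionPeriod` N requests a decision every N Academy steps. This sketch assumes Unity's fixed timestep is also 0.02 s, which is Unity's default `Time.fixedDeltaTime`; the thread only states the MuJoCo timestep. Under that assumption, it computes how often decisions are requested in simulated time. `DecisionRate` is a hypothetical helper.

```csharp
using System;

// Hypothetical helper: decision-request arithmetic, assuming the Academy
// steps once per FixedUpdate (ML-Agents' default automatic stepping).
public static class DecisionRate
{
    // Simulated seconds between two decision requests.
    public static double SecondsBetweenDecisions(double fixedDeltaTime, int decisionPeriod)
    {
        if (fixedDeltaTime <= 0) throw new ArgumentOutOfRangeException(nameof(fixedDeltaTime));
        if (decisionPeriod < 1) throw new ArgumentOutOfRangeException(nameof(decisionPeriod));
        return fixedDeltaTime * decisionPeriod;
    }

    // Decision requests per simulated second.
    public static double PerSimulatedSecond(double fixedDeltaTime, int decisionPeriod) =>
        1.0 / SecondsBetweenDecisions(fixedDeltaTime, decisionPeriod);
}
```

With `fixedDeltaTime = 0.02`, this gives about 50 decision requests per simulated second at period 1 and about 2.5 at period 20. How that maps to wall-clock FPS depends on how long each step actually takes.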

Here's the script I use to locally control it:

Local control script:

```csharp
using UnityEngine;
using Mujoco;
using System;
using System.Collections.Generic;

public class SlewJointController : MonoBehaviour
{
    static MjScene mjScene { get => MjScene.Instance; }
    private double[] initialState;
    private double[] currentState;
    private int nSteps = 0;
    public double targetHeight = 0.1;
    public double targetRadius;
    public double targetAngle;
    private PIDControl pidController;
    private double boatVel = 0;

    unsafe public void Start()
    {
        if (mjScene != null)
        {
            Debug.Log("mjScene was found in the scene");
        }
        MujocoLib.mjModel_* Model = mjScene.Model;
        MujocoLib.mjData_* Data = mjScene.Data;
        //byte[] nameBytes = Encoding.ASCII.GetBytes("tilt-joint\0");
        string namePtr = "tilt_joint";
        int id = MujocoLib.mj_name2id(Model, (int)MujocoLib.mjtObj.mjOBJ_JOINT, namePtr);
        MujocoLib.mj_resetData(Model, Data);
        int stateSize = Model->nq - 4 + 3 + Model->nv + Model->na;
        initialState = getStateInformation(mjScene, stateSize);
        currentState = new double[stateSize];
        Array.Copy(initialState, currentState, stateSize);
        Debug.Log($"New State after Step: {string.Join(", ", currentState)}");
        targetHeight = 0.1;
        (targetRadius, targetAngle) = CartesianToPolar(-0.45, 0.0);
        pidController = new PIDControl(targetAngle, targetRadius, targetHeight);
    }

    unsafe public double[] Step(double[] userAction)
    {
        double[] newState = null;
        if (mjScene != null)
        {
            MujocoLib.mjModel_* Model = mjScene.Model;
            MujocoLib.mjData_* Data = mjScene.Data;
            // userAction contains the information about the 4 velocities
            // as per XML to be executed and then ctrl sets these 4 velocities
            // that were calculated and then the next phase retrieves the current state
            // of the system after execution of the velocities and then sends those velocities
            // back again for recalculation
            for (int i = 0; i < userAction.Length; i++)
            {
                Data->ctrl[i] = userAction[i];
            }
            MujocoLib.mj_step(Model, Data);
            nSteps += 1;
            int stateSize = Model->nq - 4 + 3 + Model->nv + Model->na;
            newState = getStateInformation(mjScene, stateSize);
        }
        return newState;
    }

    unsafe public double[] getStateInformation(MjScene mjScene, int stateSize)
    {
        MujocoLib.mjModel_* Model = mjScene.Model;
        MujocoLib.mjData_* Data = mjScene.Data;
        double[] dq = new double[Model->nv];
        fixed (double* dqPtr = dq)
        {
            MujocoLib.mj_differentiatePos(Model, dqPtr, 1, Model->qpos0, Data->qpos);
        }
        stateSize = Model->nq - 4 + 3 + Model->nv + Model->na;
        double[] currState = new double[stateSize];
        int index = 0;
        List<MjBaseJoint> allJoints = new List<MjBaseJoint>();
        allJoints.AddRange(this.GetComponentsInChildren<MjHingeJoint>());
        allJoints.AddRange(this.GetComponentsInChildren<MjFreeJoint>());
        List<string> jointNames = new List<string>() { "tilt-joint", "slew-joint", "luff-joint", "free-joint", }; // order in mujocopy
        Dictionary<string, int> jointQposMap = new Dictionary<string, int>() { {"tilt-joint", 1}, {"slew-joint", 1}, {"luff-joint", 1}, {"free-joint", 7} };
        Dictionary<string, int> jointVelMap = new Dictionary<string, int>() { {"tilt-joint", 1}, {"slew-joint", 1}, {"luff-joint", 1}, {"free-joint", 6} };
        foreach (var joint in allJoints)
        {
            Debug.Log($"Name of joint is {joint.name}");
            if (jointNames.Contains(joint.name))
            {
                int qposCount = jointQposMap.GetValueOrDefault(joint.name);
                int startIndex = Model->jnt_qposadr[joint.MujocoId];
                if (string.Equals(joint.name, "free-joint", StringComparison.OrdinalIgnoreCase))
                {
                    Debug.Log("Entered inside");
                    for (int i = 0; i < qposCount - 4; i++)
                    {
                        currState[index++] = Data->qpos[startIndex + i];
                    }
                }
                else
                {
                    for (int i = 0; i < qposCount; i++)
                    {
                        currState[index++] = Data->qpos[startIndex + i];
                    }
                }
            }
        }
        // qpos rotation values not required for freeJoint
        // currState = currState.Take(currState.Length - 4).ToArray();
        // add dq only for quaternion changes in free joint rotations
        int qposFreeJnt = jointVelMap.GetValueOrDefault("free-joint");
        // 3 dofs for translational velocity are not included
        for (int i = qposFreeJnt - 3; i < qposFreeJnt; i++)
        {
            currState[index++] = dq[i];
        }
        foreach (var joint in allJoints)
        {
            if (jointNames.Contains(joint.name))
            {
                int qvelCount = jointVelMap.GetValueOrDefault(joint.name);
                int startIndex = Model->jnt_dofadr[joint.MujocoId];
                for (int i = 0; i < qvelCount; i++)
                {
                    currState[index++] = Data->qvel[startIndex + i];
                }
            }
        }
        // Actuator activation list
        for (int i = 0; i < Model->na; i++)
        {
            currState[index++] = Data->act[i];
        }
        return currState;
    }

    unsafe public void FixedUpdate()
    {
        if (mjScene != null)
        {
            double boatVel = 0.05;
            List<double> data = new List<double>();
            // Close the viewer automatically after 30 wall-seconds.
            float start = Time.time;
            if (Time.time - start < 1000)
            {
                float stepStart = Time.time;
                // Current_state is [pos_tilt_joint, pos_slew_joint, pos_luff_joint, pos_free_joint_r, pos_free_joint_p, pos_free_joint_y]
                Vector4 action = pidController.GeneratePidAction(currentState, boatVel, mjScene.Data->qvel);
                Debug.Log($"Value received from PID: {action.x}, {action.y}, {action.z}, {action.w}");
                double[] actionArray = new double[] { action.x, action.y, action.z, action.w };
                currentState = Step(actionArray);
                Debug.Log($"New State after Step: {string.Join(", ", currentState)}");
                data.Add(mjScene.Data->qvel[1]);
                // Rudimentary time keeping, will drift relative to wall clock.
                float timeUntilNextStep = (float)mjScene.Model->opt.timestep - (Time.time - stepStart);
            }
        }
    }

    private MjHingeJoint FindJointByName(string name)
    {
        MjHingeJoint[] joints = GetComponentsInChildren<MjHingeJoint>();
        foreach (MjHingeJoint joint in joints)
        {
            if (joint.name == name)
            {
                return joint;
            }
        }
        return null;
    }

    private (double, double) CartesianToPolar(double x, double y)
    {
        double radius = Math.Sqrt(x * x + y * y);
        double angle = Math.Atan2(y, x);
        float unityAngle = (float)(angle * Mathf.Rad2Deg);
        return (radius, unityAngle);
    }
}
```

Balint-H commented 1 month ago

That type of scene should easily reach a few hundred FPS, even with ML Agents. Sorry that I can't take a more in-depth look right now, but using the profiler and trying to isolate the cause (e.g. by commenting out the call to the getState function, or the call to the differentiation API) could help. Also, if you are looking for examples of Mj + ML Agents, you can take a look at how scenes are constructed in ModularAgents, a repo of mine that has components, extensions and scenes for this kind of setup.
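As a lightweight complement to the Unity Profiler (which can also be given custom `Unity.Profiling.ProfilerMarker` scopes), each candidate call can be timed in isolation, for example the state-gathering function or the `mj_differentiatePos` call. The sketch below uses plain .NET `System.Diagnostics.Stopwatch`. `CallTimer` is a hypothetical helper, not part of ML-Agents or the MuJoCo plugin.

```csharp
using System;
using System.Diagnostics;

// Hypothetical helper: average wall-clock cost of a call, in milliseconds.
public static class CallTimer
{
    public static double AverageMilliseconds(Action action, int iterations)
    {
        if (action == null) throw new ArgumentNullException(nameof(action));
        if (iterations < 1) throw new ArgumentOutOfRangeException(nameof(iterations));
        action(); // warm-up call, so JIT compilation is not included in the timing
        var stopwatch = Stopwatch.StartNew();
        for (int i = 0; i < iterations; i++)
        {
            action();
        }
        stopwatch.Stop();
        return stopwatch.Elapsed.TotalMilliseconds / iterations;
    }
}
```

Inside the agent, a call like `CallTimer.AverageMilliseconds(() => GetStateInformation(mjScene, stateSize), 1000)` would report the per-call cost. At the 0.02 s timestep used here, the whole per-step cost has to stay under roughly 20 ms to keep up with real time.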

dhruvmsheth commented 1 month ago

Thanks for the help! I'll try that out and let you know

github-actions[bot] commented 1 week ago

This issue is stale because it has been open for 30 days with no activity.