Unity-Technologies / ml-agents

The Unity Machine Learning Agents Toolkit (ML-Agents) is an open-source project that enables games and simulations to serve as environments for training intelligent agents using deep reinforcement learning and imitation learning.
https://unity.com/products/machine-learning-agents
Other
16.93k stars 4.13k forks source link

Agent Action- Mouse Velocity, Roller Ball ain't learning? #3001

Closed Zilch123 closed 4 years ago

Zilch123 commented 4 years ago

I am using the mouse as input rather than the keyboard keys. The velocity of the mouse is taken as the current mouse position minus the previous mouse position. The VectorActions are the velocity of the mouse.

When the mouse position is used as the input, the agent learns. When the mouse velocity is used instead, it doesn't work: training only reports "No episode was completed since last summary." In player mode, the velocity is calculated in Heuristic() and given as the action.

The velocity of the mouse controls the ball's velocity. Deltax and Deltay usually range from -1.0 to 1.0.

Kinematics is enabled and scripting controls the whole game.


using UnityEngine;
using MLAgents;
using System.Collections;
using System;
using System.Threading;
using Random = UnityEngine.Random;
using UnityEngine.UI;

/// <summary>
/// ML-Agents agent that moves a ball on a platform towards a "Target" object
/// while avoiding a "Trap" object.
/// The two-element action vector is added to a running displacement that is
/// applied to the ball every step; observations are ray-perception results
/// plus the ball's local x/z velocity.
/// </summary>
public class RollerAgent : Agent
{
    public Rigidbody rBody;
    public Rigidbody Target;
    public Rigidbody Trap;
    public float speed;
    public Text scoreTxt;

    private RayPerception3D rayPer;

    // True until Heuristic() has been called once; the first call only records
    // the axis reading so that the first reported delta is zero.
    private bool firstHeuristicCall;

    // Running sum of every action received so far; applied as the per-step translation.
    private Vector3 accumulatedMove;

    // World-space "down" direction captured once in Start(), used for floor raycasts.
    private Vector3 below;

    private int score, trapped, fallen, numberofsteps;
    private float[] vel_action, pos_action, prev_pos_action;

    // Handle of the placement coroutine currently waiting to re-show Target/Trap, if any.
    private Coroutine placementRoutine;

    /// <summary>
    /// Caches component references and initialises counters, buffers, cursor and UI.
    /// </summary>
    void Start()
    {
        rBody = GetComponent<Rigidbody>();
        Target = GameObject.FindWithTag("Target").GetComponent<Rigidbody>();
        Trap = GameObject.FindWithTag("Trap").GetComponent<Rigidbody>();
        rayPer = GetComponent<RayPerception3D>();
        below = transform.TransformDirection(Vector3.down);

        // Initialise all agent state before touching the UI so that a missing
        // scoreTxt reference cannot leave the heuristic buffers unallocated.
        score = 0;
        fallen = 0;
        trapped = 0;
        numberofsteps = 0;
        firstHeuristicCall = true;
        vel_action = new float[2] { 0f, 0f };
        pos_action = new float[2];
        prev_pos_action = new float[2];

        Cursor.visible = false;
        // Cursor.lockState replaces the obsolete Screen.lockCursor property.
        Cursor.lockState = CursorLockMode.Locked;
        Screen.SetResolution(10000, 10000, true);

        DisplayScore();
    }

    /// <summary>
    /// Episode reset: logs and clears the step counter, re-centres the ball if
    /// there is no floor beneath it, and re-places the Trap and Target.
    /// </summary>
    public override void AgentReset()
    {
        Debug.Log("numberofsteps" + numberofsteps);
        numberofsteps = 0;

        if (!Physics.Raycast(transform.position, below, 10f))
        {
            // If the Agent fell, zero its momentum
            rBody.angularVelocity = Vector3.zero;
            rBody.velocity = Vector3.zero;
            transform.position = new Vector3(0, 0.5f, 0);
        }

        // NOTE(review): accumulatedMove is not cleared here, so the displacement
        // built up in one episode carries over into the next — confirm this is intended.

        // Stop a placement that is still waiting from an earlier reset; otherwise it
        // would re-show Target/Trap while the new placement wants them hidden.
        if (placementRoutine != null)
        {
            StopCoroutine(placementRoutine);
        }
        placementRoutine = StartCoroutine(Wait_and_put_target_n_Trap(1.1f));
    }

    /// <summary>
    /// Moves Trap and Target to random spots, hides both by scaling them to zero,
    /// waits <paramref name="time_"/> real-time seconds, then restores their scale.
    /// </summary>
    /// <param name="time_">Unscaled seconds to keep Trap and Target hidden.</param>
    IEnumerator Wait_and_put_target_n_Trap(float time_)
    {
        // Rejection-sample a trap position more than 2.5 units from the agent.
        Vector3 trapPos;
        do
        {
            trapPos = RandomPlatformPoint();
        } while (!(Vector3.Distance(trapPos, transform.position) > 2.5));
        Trap.position = trapPos;

        // Rejection-sample a target position more than 2.5 units from the trap
        // and more than 1 unit from the agent.
        Vector3 targetPos;
        do
        {
            targetPos = RandomPlatformPoint();
        } while (!(Vector3.Distance(trapPos, targetPos) > 2.5
                   && Vector3.Distance(transform.position, targetPos) > 1));
        Target.position = targetPos;

        SetMarkerScale(Vector3.zero);

        yield return new WaitForSecondsRealtime(time_);

        SetMarkerScale(Vector3.one);
        placementRoutine = null;
    }

    // Returns a random point with x and z in [-4, 4] at height 0.5.
    private static Vector3 RandomPlatformPoint()
    {
        float x = Random.value * 8 - 4;
        float z = Random.value * 8 - 4;
        return new Vector3(x, 0.5f, z);
    }

    // Sets the local scale of the scene objects named "Trap" and "Target".
    private static void SetMarkerScale(Vector3 scale)
    {
        GameObject.Find("Trap").transform.localScale = scale;
        GameObject.Find("Target").transform.localScale = scale;
    }

    /// <summary>
    /// Collects the network inputs: ray perception over 36 angles (0°–350° in 10°
    /// steps) for objects tagged "Target" and "Trap", followed by the ball's
    /// local-space x and z velocity.
    /// </summary>
    public override void CollectObservations()
    {
        // Ray perception
        float rayDistance = 50f;
        float[] rayAngles =
        {
            0f, 10f, 20f, 30f, 40f, 50f, 60f, 70f, 80f, 90f, 100f, 110f,
            120f, 130f, 140f, 150f, 160f, 170f, 180f, 190f, 200f, 210f, 220f, 230f,
            240f, 250f, 260f, 270f, 280f, 290f, 300f, 310f, 320f, 330f, 340f, 350f
        };
        string[] detectableObjects = { "Target", "Trap" };
        AddVectorObs(rayPer.Perceive(rayDistance, rayAngles, detectableObjects, 0f, 0f));

        // Agent local velocity
        Vector3 localVelocity = transform.InverseTransformDirection(rBody.velocity);
        AddVectorObs(localVelocity.x);
        AddVectorObs(localVelocity.z);
    }

    /// <summary>
    /// Writes the score, trapped and fallen counters to <see cref="scoreTxt"/>.
    /// Does nothing when no Text component has been assigned.
    /// </summary>
    public void DisplayScore()
    {
        // Explicit == null so Unity's overloaded null check also catches destroyed objects.
        if (scoreTxt == null)
        {
            return;
        }

        scoreTxt.text = "Score: " + score.ToString() + "\n"
                      + "Trapped: " + trapped.ToString() + "\n"
                      + "Fallen: " + fallen.ToString();
    }

    /// <summary>
    /// Human play mode: returns the change in the "Horizontal" and "Vertical"
    /// input axes since the previous call.
    /// </summary>
    /// <returns>
    /// The shared two-element array {delta horizontal, delta vertical}; it holds
    /// zeros on the first call.
    /// </returns>
    public override float[] Heuristic()
    {
        pos_action[0] = Input.GetAxis("Horizontal");
        pos_action[1] = Input.GetAxis("Vertical");

        if (firstHeuristicCall)
        {
            // No previous sample yet: remember this one and report no movement.
            prev_pos_action[0] = pos_action[0];
            prev_pos_action[1] = pos_action[1];
            firstHeuristicCall = false;
            return vel_action;
        }

        vel_action[0] = pos_action[0] - prev_pos_action[0];
        vel_action[1] = pos_action[1] - prev_pos_action[1];

        prev_pos_action[0] = pos_action[0];
        prev_pos_action[1] = pos_action[1];

        // Separate the two components; concatenating them directly is unreadable.
        Debug.Log("Velaction " + vel_action[0] + ", " + vel_action[1]);
        return vel_action;
    }

    /// <summary>
    /// Applies one action: moves the ball, assigns rewards, ends the episode when
    /// the Target or Trap is reached, and keeps the ball on the platform.
    /// </summary>
    /// <param name="vectorAction">Two values: x and z increments for the ball's per-step displacement.</param>
    /// <param name="textAction">Unused.</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        // Count only the steps in which the action magnitude is non-negligible.
        float actionMagnitude = Math.Abs(vectorAction[0]) + Math.Abs(vectorAction[1]);
        if (actionMagnitude >= 0.05)
        {
            numberofsteps = numberofsteps + 1;
        }

        // The action is added to a running sum, and the whole sum is applied as
        // this step's translation.
        accumulatedMove += new Vector3(vectorAction[0], 0.0f, vectorAction[1]);
        rBody.transform.Translate(accumulatedMove);

        // Rewards: small penalty on every step.
        SetReward(-0.01f);

        float distanceToTarget = Vector3.Distance(transform.position, Target.position);
        float distanceToTrap = Vector3.Distance(transform.position, Trap.position);

        // Reached target
        if (distanceToTarget < 0.8f)
        {
            SetReward(1.0f);
            score = score + 1;
            Done();
            DisplayScore();
        }

        // Reached Trap
        if (distanceToTrap < 0.8f)
        {
            SetReward(-1.0f);
            score = score - 1;
            trapped = trapped + 1;
            Done();
            DisplayScore();
        }

        // Trying to go outside the platform: no floor below the ball.
        if (!Physics.Raycast(transform.position, below, 10f))
        {
            Vector3 current = transform.position;

            // Pull the ball back to the edge (and to height 0.5) only when it has
            // reached or passed the ±5 boundary on at least one axis.
            if (Mathf.Abs(current.x) >= 5f || Mathf.Abs(current.z) >= 5f)
            {
                transform.position = new Vector3(
                    Mathf.Clamp(current.x, -5f, 5f),
                    0.5f,
                    Mathf.Clamp(current.z, -5f, 5f));
            }
        }
    }
}
Zilch123 commented 4 years ago

Solved by changing the reward structure and by tuning the hyperparameters. ~80% accuracy

github-actions[bot] commented 3 years ago

This thread has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.