using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;

/// <summary>
/// ML-Agents agent that rolls a Rigidbody ball toward a randomly placed target.
/// Actions: two discrete branches (index 0 = x axis, index 1 = z axis). Each branch
/// value is decoded as 0 = no force, 1 = positive direction, 2 = negative direction.
/// NOTE(review): assumes each discrete branch is configured with size 3 in the
/// Behavior Parameters component — confirm in the inspector.
/// </summary>
public class RollerAgent : Agent
{
    public Transform Target;
    public Transform MinLocation;
    public GameObject UPArrow;
    public GameObject DownArrow;
    public GameObject LArrow;
    public GameObject RPArrow;
    public int timeLimit = 8;
    public float forceMultiplier = 10;
    public float WINREWARD = 10.0f;
    public float FAILREWARD = -10.0f;

    // Distance below which the target counts as reached.
    private const float TargetReachedDistance = 1.42f;

    // Minimum spawn distance between agent and target; kept above
    // TargetReachedDistance so an episode never starts already won.
    private const float MinSpawnDistance = 1.45f;

    // Half-extent of the square area (in local units) the target spawns in.
    private const float SpawnHalfExtent = 4f;

    // Discrete branch index -> direction encoding shared by Heuristic and OnActionReceived.
    private const int ActionNone = 0;
    private const int ActionPositive = 1;
    private const int ActionNegative = 2;

    // Closest distance to the target reached so far in this episode.
    private float minDistance = 0f;
    private Rigidbody rBody;
    private float startTime = 0;

    // Arrow indicators spawned by THIS agent only, so several agents can share a scene
    // without destroying each other's arrows.
    private readonly List<GameObject> spawnedArrows = new List<GameObject>();

    /// <summary>
    /// One-time setup. Runs from Agent.OnEnable, i.e. before any OnEpisodeBegin call,
    /// which Start() does not guarantee.
    /// </summary>
    public override void Initialize()
    {
        rBody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Called at the start of each episode: resets the timer, recovers the agent if it
    /// fell off the platform, and places the target at a random position that is not
    /// too close to the agent.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        startTime = Time.time;

        // If the agent fell, zero its momentum and put it back on the platform.
        if (transform.localPosition.y < 0)
        {
            rBody.angularVelocity = Vector3.zero;
            rBody.velocity = Vector3.zero;
            transform.localPosition = new Vector3(0, 0.5f, 0);
        }

        // Re-sample until the target is far enough away from the agent.
        Vector3 newTargetPosition;
        float dist;
        do
        {
            newTargetPosition = RandomTargetPosition();
            dist = Vector3.Distance(transform.localPosition, newTargetPosition);
        } while (dist <= MinSpawnDistance);

        Target.localPosition = newTargetPosition;
        minDistance = dist;
        MinLocation.localPosition = transform.localPosition;
    }

    /// <summary>Random target position in [-SpawnHalfExtent, SpawnHalfExtent] on x/z, at height 0.5.</summary>
    private static Vector3 RandomTargetPosition()
    {
        float span = SpawnHalfExtent * 2f;
        return new Vector3(Random.value * span - SpawnHalfExtent, 0.5f, Random.value * span - SpawnHalfExtent);
    }

    /// <summary>
    /// Observations (6 floats): target x/z, agent x/z (local positions), agent velocity x/z.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Target and agent positions
        sensor.AddObservation(Target.localPosition.x);
        sensor.AddObservation(Target.localPosition.z);
        sensor.AddObservation(transform.localPosition.x);
        sensor.AddObservation(transform.localPosition.z);

        // Agent velocity
        sensor.AddObservation(rBody.velocity.x);
        sensor.AddObservation(rBody.velocity.z);
    }

    /// <summary>
    /// Applies a force to the ball and shows arrow indicators for the chosen directions.
    /// </summary>
    /// <param name="action_x">Signed direction along x: -1, 0 or 1.</param>
    /// <param name="action_z">Signed direction along z: -1, 0 or 1.</param>
    public void MoveBall(int action_x, int action_z)
    {
        Vector3 controlSignal = Vector3.zero;
        controlSignal.x = action_x;
        controlSignal.z = action_z;

        ClearArrows();

        // Instantiate takes a WORLD-space position, so offset from transform.position
        // (not localPosition) to keep arrows next to the ball when the area is moved.
        Vector3 origin = transform.position;
        if (action_x == 1)
        {
            SpawnArrow(UPArrow, origin + new Vector3(1.0f, 0.5f, 0f), Quaternion.Euler(0, 0, 0));
        }
        else if (action_x == -1)
        {
            SpawnArrow(DownArrow, origin + new Vector3(-1.0f, 0.5f, 0f), Quaternion.Euler(0, 180, 0));
        }

        if (action_z == 1)
        {
            SpawnArrow(LArrow, origin + new Vector3(0f, 0.5f, 1.0f), Quaternion.Euler(0, 270, 0));
        }
        else if (action_z == -1)
        {
            SpawnArrow(RPArrow, origin + new Vector3(0f, 0.5f, -1.0f), Quaternion.Euler(0, 90, 0));
        }

        rBody.AddForce(controlSignal * forceMultiplier);
    }

    /// <summary>Instantiates an arrow indicator and remembers it for later cleanup.</summary>
    private void SpawnArrow(GameObject prefab, Vector3 worldPosition, Quaternion rotation)
    {
        spawnedArrows.Add(Instantiate(prefab, worldPosition, rotation));
    }

    /// <summary>Destroys the arrow indicators previously spawned by this agent.</summary>
    private void ClearArrows()
    {
        foreach (GameObject arrow in spawnedArrows)
        {
            // Unity's overloaded == also treats already-destroyed objects as null.
            if (arrow != null)
            {
                Destroy(arrow);
            }
        }
        spawnedArrows.Clear();
    }

    /// <summary>Maps a discrete branch index to a signed direction (-1, 0 or 1).</summary>
    private static int DecodeDirection(int actionIndex)
    {
        if (actionIndex == ActionPositive)
        {
            return 1;
        }
        return actionIndex == ActionNegative ? -1 : 0;
    }

    /// <summary>Maps a signed direction to a non-negative discrete branch index.</summary>
    private static int EncodeDirection(int direction)
    {
        if (direction > 0)
        {
            return ActionPositive;
        }
        return direction < 0 ? ActionNegative : ActionNone;
    }

    /// <summary>
    /// Applies the received actions and assigns the step reward:
    /// FAILREWARD and episode end on timeout or falling off, WINREWARD and episode end
    /// on reaching the target, otherwise (best distance so far - current distance).
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Two discrete branches; policies emit indices >= 0, so decode them to -1/0/1.
        ActionSegment<int> discreteActions = actionBuffers.DiscreteActions;
        MoveBall(DecodeDirection(discreteActions[0]), DecodeDirection(discreteActions[1]));

        float nowDistanceToTarget = Vector3.Distance(transform.localPosition, Target.localPosition);

        if (Time.time - startTime >= timeLimit || transform.localPosition.y < 0)
        {
            // Time is up, or the agent fell off the play area.
            SetReward(FAILREWARD);
            EndEpisode();
        }
        else if (nowDistanceToTarget < TargetReachedDistance)
        {
            // Reached the target.
            SetReward(WINREWARD);
            EndEpisode();
        }
        else
        {
            // Shaping reward: positive by the amount the best distance improved, negative
            // by how far the agent is behind its best, zero when equal.
            SetReward(minDistance - nowDistanceToTarget);
            if (nowDistanceToTarget < minDistance)
            {
                minDistance = nowDistanceToTarget;
                MinLocation.localPosition = transform.localPosition;
            }
        }
    }

    /// <summary>
    /// Keyboard control for debugging: W/S move along +x/-x, A/D along +z/-z.
    /// Writes branch indices using the same encoding OnActionReceived decodes.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;

        int inpX = 0;
        if (Input.GetKey(KeyCode.W))
        {
            inpX = 1;
        }
        else if (Input.GetKey(KeyCode.S))
        {
            inpX = -1;
        }

        int inpZ = 0;
        if (Input.GetKey(KeyCode.A))
        {
            inpZ = 1;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            inpZ = -1;
        }

        discreteActions[0] = EncodeDirection(inpX);
        discreteActions[1] = EncodeDirection(inpZ);
    }
}