Source upload
This commit is contained in:
@@ -0,0 +1,8 @@
|
||||
fileFormatVersion: 2
|
||||
guid: ba4ec306400aafb4083615ca8c9d2ad2
|
||||
folderAsset: yes
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
@@ -0,0 +1,38 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import mlagents\n",
|
||||
"import gym_unity"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "c62a1b52b24525839a95f7ca2b53f501cc329096d80c6be9aea5c814c594ecdd"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.9.7 64-bit",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.7"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
fileFormatVersion: 2
|
||||
guid: 459e9567ac1c1a344af83899c81323ef
|
||||
DefaultImporter:
|
||||
externalObjects: {}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
@@ -0,0 +1,194 @@
|
||||
using System.Collections.Generic;
|
||||
using UnityEngine;
|
||||
using Unity.MLAgents;
|
||||
using Unity.MLAgents.Sensors;
|
||||
using Unity.MLAgents.Actuators;
|
||||
|
||||
/// <summary>
/// ML-Agents agent that pushes a rigidbody ball towards a randomly placed target.
/// </summary>
/// <remarks>
/// Observations (6 floats): target x/z, agent x/z (both local space) and the
/// rigidbody velocity x/z.
/// Actions: two discrete branches, one per axis (x, z). Each branch value is
/// decoded as 0 = no push, 1 = push in the positive direction, 2 = push in the
/// negative direction. This assumes the Behavior Parameters declare two
/// branches of size 3 — TODO confirm in the scene/prefab.
/// </remarks>
public class RollerAgent : Agent
{
    // Half-width of the square in which the target is placed: Random.value * 8 - 4.
    private const float TargetAreaSize = 8f;
    private const float TargetAreaOffset = 4f;

    // Height (local y) at which both the target and a respawned agent are placed.
    private const float SpawnHeight = 0.5f;

    // A freshly placed target must be strictly farther than this from the agent.
    private const float MinSpawnDistance = 1.45f;

    // The target counts as reached once the agent is closer than this.
    private const float ReachDistance = 1.42f;

    // Placement of the direction arrows relative to the agent.
    private const float ArrowSideOffset = 1.0f;
    private const float ArrowHeightOffset = 0.5f;

    // NOTE: the public field names below are kept exactly as they were because
    // Unity serializes Inspector values by field name.

    /// <summary>Transform of the target the agent has to reach.</summary>
    public Transform Target;

    /// <summary>Marker moved to the agent position of the closest approach recorded so far.</summary>
    public Transform MinLocation;

    /// <summary>Prefab shown when the agent pushes in +x.</summary>
    public GameObject UPArrow;

    /// <summary>Prefab shown when the agent pushes in -x.</summary>
    public GameObject DownArrow;

    /// <summary>Prefab shown when the agent pushes in +z.</summary>
    public GameObject LArrow;

    /// <summary>Prefab shown when the agent pushes in -z.</summary>
    public GameObject RPArrow;

    /// <summary>Episode length limit, compared against elapsed <c>Time.time</c>.</summary>
    public int timeLimit = 8;

    /// <summary>Scale applied to the unit control signal before it is passed to <c>AddForce</c>.</summary>
    public float forceMultiplier = 10;

    /// <summary>Reward set when the target is reached.</summary>
    public float WINREWARD = 10.0f;

    /// <summary>Reward set when the episode times out or the agent falls off.</summary>
    public float FAILREWARD = -10.0f;

    // Smallest distance to the target seen so far in the current episode.
    private float minDistance = 0f;

    Rigidbody rBody;

    // Time.time at which the current episode started.
    private float startTime = 0;

    // Arrows spawned by THIS agent on its previous action. Tracking them here
    // avoids a scene-wide tag search on every step and keeps one agent from
    // destroying the arrows of another agent in the same scene.
    private readonly List<GameObject> spawnedArrows = new List<GameObject>();

    void Start()
    {
        rBody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Runs at the start of every episode: restarts the episode timer, respawns
    /// the agent if it fell below the floor, and places the target at a random
    /// position that is farther than <c>MinSpawnDistance</c> from the agent.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        startTime = Time.time;

        // If the agent fell, zero its momentum and put it back at the centre.
        if (transform.localPosition.y < 0)
        {
            rBody.angularVelocity = Vector3.zero;
            rBody.velocity = Vector3.zero;
            transform.localPosition = new Vector3(0, SpawnHeight, 0);
        }

        // Re-sample until the target is not practically on top of the agent.
        Vector3 candidate;
        float distance;
        do
        {
            candidate = new Vector3(
                Random.value * TargetAreaSize - TargetAreaOffset,
                SpawnHeight,
                Random.value * TargetAreaSize - TargetAreaOffset);
            distance = Vector3.Distance(transform.localPosition, candidate);
        }
        while (distance <= MinSpawnDistance);

        Target.localPosition = candidate;

        // The starting distance is the baseline for the progress reward.
        minDistance = distance;
        MinLocation.localPosition = transform.localPosition;
    }

    /// <summary>
    /// Adds six observations: target x/z, agent x/z (local positions) and the
    /// agent's rigidbody velocity x/z.
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        Vector3 targetPosition = Target.localPosition;
        Vector3 agentPosition = transform.localPosition;
        Vector3 velocity = rBody.velocity;

        // Target and agent positions.
        sensor.AddObservation(targetPosition.x);
        sensor.AddObservation(targetPosition.z);
        sensor.AddObservation(agentPosition.x);
        sensor.AddObservation(agentPosition.z);

        // Agent velocity.
        sensor.AddObservation(velocity.x);
        sensor.AddObservation(velocity.z);
    }

    /// <summary>
    /// Pushes the ball and refreshes the direction arrows around it.
    /// </summary>
    /// <param name="action_x">Signed push along x; 1 and -1 also show an arrow.</param>
    /// <param name="action_z">Signed push along z; 1 and -1 also show an arrow.</param>
    public void MoveBall(int action_x, int action_z)
    {
        ClearArrows();

        // Instantiate takes a world-space position, so use transform.position
        // rather than localPosition; the two differ as soon as the agent's
        // parent is not at the origin.
        Vector3 origin = transform.position;

        if (action_x == 1)
        {
            SpawnArrow(UPArrow, origin + new Vector3(ArrowSideOffset, ArrowHeightOffset, 0f), 0f);
        }
        else if (action_x == -1)
        {
            SpawnArrow(DownArrow, origin + new Vector3(-ArrowSideOffset, ArrowHeightOffset, 0f), 180f);
        }

        if (action_z == 1)
        {
            SpawnArrow(LArrow, origin + new Vector3(0f, ArrowHeightOffset, ArrowSideOffset), 270f);
        }
        else if (action_z == -1)
        {
            SpawnArrow(RPArrow, origin + new Vector3(0f, ArrowHeightOffset, -ArrowSideOffset), 90f);
        }

        Vector3 controlSignal = new Vector3(action_x, 0f, action_z);
        rBody.AddForce(controlSignal * forceMultiplier);
    }

    /// <summary>
    /// Applies the received discrete actions and assigns the step reward.
    /// </summary>
    /// <remarks>
    /// Reward scheme: <c>FAILREWARD</c> and end of episode on timeout or fall;
    /// <c>WINREWARD</c> and end of episode when closer than <c>ReachDistance</c>;
    /// otherwise <c>minDistance - currentDistance</c>, which is positive when a
    /// new closest approach is reached and negative when the agent is farther
    /// away than its best approach.
    /// </remarks>
    /// <param name="actionBuffers">Action buffers; discrete branches 0 and 1 are read.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Discrete branch values are indices starting at 0, so they have to be
        // decoded into a signed direction before they are used as a force.
        int inpX = DecodeAxis(actionBuffers.DiscreteActions[0]);
        int inpZ = DecodeAxis(actionBuffers.DiscreteActions[1]);
        MoveBall(inpX, inpZ);

        float nowDistanceToTarget = Vector3.Distance(transform.localPosition, Target.localPosition);

        bool timedOut = Time.time - startTime >= timeLimit;
        bool fellOff = transform.localPosition.y < 0;
        if (timedOut || fellOff)
        {
            // Time is up or the agent fell from the play area.
            SetReward(FAILREWARD);
            EndEpisode();
            return;
        }

        if (nowDistanceToTarget < ReachDistance)
        {
            // Reached the target.
            SetReward(WINREWARD);
            EndEpisode();
            return;
        }

        // Progress reward relative to the best approach so far (0 when equal).
        SetReward(minDistance - nowDistanceToTarget);

        if (nowDistanceToTarget < minDistance)
        {
            // New closest approach: remember it and move the marker there.
            minDistance = nowDistanceToTarget;
            MinLocation.localPosition = transform.localPosition;
        }
    }

    /// <summary>
    /// Keyboard control for manual testing: W/S push along +x/-x, A/D push
    /// along +z/-z. Values are written using the same 0/1/2 encoding that
    /// <see cref="OnActionReceived"/> decodes.
    /// </summary>
    /// <param name="actionsOut">Action buffers whose discrete branches 0 and 1 are filled.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;

        int branchX = 0;
        if (Input.GetKey(KeyCode.W))
        {
            branchX = 1;
        }
        else if (Input.GetKey(KeyCode.S))
        {
            branchX = 2;
        }

        int branchZ = 0;
        if (Input.GetKey(KeyCode.A))
        {
            branchZ = 1;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            branchZ = 2;
        }

        discreteActions[0] = branchX;
        discreteActions[1] = branchZ;
    }

    /// <summary>
    /// Converts a discrete branch value into a signed axis direction:
    /// 1 -> +1, 2 -> -1, anything else -> 0. A raw -1 is still accepted as -1
    /// so callers that used the previous signed encoding keep working.
    /// </summary>
    private static int DecodeAxis(int branchValue)
    {
        switch (branchValue)
        {
            case 1:
                return 1;
            case 2:
            case -1:
                return -1;
            default:
                return 0;
        }
    }

    /// <summary>Destroys the arrows this agent spawned on its previous action.</summary>
    private void ClearArrows()
    {
        foreach (GameObject arrow in spawnedArrows)
        {
            // Unity's overloaded null check also covers already-destroyed objects.
            if (arrow != null)
            {
                Destroy(arrow);
            }
        }

        spawnedArrows.Clear();
    }

    /// <summary>
    /// Instantiates an arrow prefab at a world position with the given yaw and
    /// records it for later clean-up. An unassigned prefab is skipped.
    /// </summary>
    private void SpawnArrow(GameObject prefab, Vector3 worldPosition, float yawDegrees)
    {
        if (prefab == null)
        {
            return;
        }

        spawnedArrows.Add(Instantiate(prefab, worldPosition, Quaternion.Euler(0f, yawDegrees, 0f)));
    }
}
|
||||
@@ -0,0 +1,11 @@
|
||||
fileFormatVersion: 2
|
||||
guid: c76cbcf0bd8edb346b372986bec61e97
|
||||
MonoImporter:
|
||||
externalObjects: {}
|
||||
serializedVersion: 2
|
||||
defaultReferences: []
|
||||
executionOrder: 0
|
||||
icon: {instanceID: 0}
|
||||
userData:
|
||||
assetBundleName:
|
||||
assetBundleVariant:
|
||||
Reference in New Issue
Block a user