Source upload

This commit is contained in:
2024-03-05 19:02:19 +09:00
commit 26d5e01d34
364 changed files with 41590 additions and 0 deletions
+8
View File
@@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: ba4ec306400aafb4083615ca8c9d2ad2
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:
+38
View File
@@ -0,0 +1,38 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import mlagents\n",
"import gym_unity"
]
}
],
"metadata": {
"interpreter": {
"hash": "c62a1b52b24525839a95f7ca2b53f501cc329096d80c6be9aea5c814c594ecdd"
},
"kernelspec": {
"display_name": "Python 3.9.7 64-bit",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
+7
View File
@@ -0,0 +1,7 @@
fileFormatVersion: 2
guid: 459e9567ac1c1a344af83899c81323ef
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:
+194
View File
@@ -0,0 +1,194 @@
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
/// <summary>
/// ML-Agents agent that pushes a rigidbody ball towards a randomly placed target.
/// </summary>
/// <remarks>
/// The agent uses two discrete action branches (X axis and Z axis). Each branch is
/// decoded as: 0 = no force, 1 = positive direction, 2 = negative direction.
/// ML-Agents discrete branches only produce non-negative indices, so the negative
/// direction must be encoded as a non-negative index; the Behavior Parameters
/// should therefore declare a branch size of 3 for both branches.
/// </remarks>
public class RollerAgent : Agent
{
    // Discrete branch encoding shared by OnActionReceived and Heuristic.
    private const int ActionNone = 0;
    private const int ActionPositive = 1;
    private const int ActionNegative = 2;

    // Target placement: coordinates are drawn from [-HalfArenaSize, HalfArenaSize]
    // and must be further than MinSpawnDistance from the agent.
    private const float HalfArenaSize = 4f;
    private const float TargetHeight = 0.5f;
    private const float MinSpawnDistance = 1.45f;

    // Distance below which the target counts as reached.
    private const float ReachDistance = 1.42f;

    public Transform Target;
    public Transform MinLocation;
    public GameObject UPArrow;
    public GameObject DownArrow;
    public GameObject LArrow;
    public GameObject RPArrow;
    public int timeLimit = 8;
    public float forceMultiplier = 10;
    public float WINREWARD = 10.0f;
    public float FAILREWARD = -10.0f;

    // Smallest agent-to-target distance seen so far in the current episode.
    private float minDistance = 0f;
    Rigidbody rBody;
    // Time.time at which the current episode began.
    private float startTime = 0;

    // Arrow instances spawned by this agent on its latest move. Tracked here so
    // that only this agent's arrows are removed on the next move.
    private readonly List<GameObject> spawnedArrows = new List<GameObject>();

    void Start()
    {
        rBody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Runs at the start of every episode: restarts the episode timer, recovers the
    /// agent if it fell off the floor, and places the target at a new random position.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        // Restart the episode timer.
        startTime = Time.time;

        // If the agent fell, zero its momentum and put it back on the floor.
        if (transform.localPosition.y < 0)
        {
            rBody.angularVelocity = Vector3.zero;
            rBody.velocity = Vector3.zero;
            transform.localPosition = new Vector3(0, 0.5f, 0);
        }

        // Draw target positions until one is far enough away from the agent.
        Vector3 newTargetPosition;
        float dist;
        do
        {
            newTargetPosition = RandomTargetPosition();
            dist = Vector3.Distance(transform.localPosition, newTargetPosition);
        }
        while (dist <= MinSpawnDistance);

        Target.localPosition = newTargetPosition;
        minDistance = dist;
        MinLocation.localPosition = transform.localPosition;
    }

    /// <summary>
    /// Collects six observations: target x/z, agent x/z (local space) and agent velocity x/z.
    /// </summary>
    /// <param name="sensor">Vector sensor that receives the observations.</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Target and agent positions.
        sensor.AddObservation(Target.localPosition.x);
        sensor.AddObservation(Target.localPosition.z);
        sensor.AddObservation(transform.localPosition.x);
        sensor.AddObservation(transform.localPosition.z);

        // Agent velocity.
        sensor.AddObservation(rBody.velocity.x);
        sensor.AddObservation(rBody.velocity.z);
    }

    /// <summary>
    /// Applies a force to the ball and shows direction arrows next to it.
    /// </summary>
    /// <param name="action_x">Signed direction along X: -1, 0 or 1.</param>
    /// <param name="action_z">Signed direction along Z: -1, 0 or 1.</param>
    public void MoveBall(int action_x, int action_z)
    {
        Vector3 controlSignal = new Vector3(action_x, 0f, action_z);

        // Remove the arrows from the previous move before drawing the new ones.
        ClearArrows();

        if (action_x == 1)
        {
            SpawnArrow(UPArrow, new Vector3(1.0f, 0.5f, 0.0f), 0f);
        }
        else if (action_x == -1)
        {
            SpawnArrow(DownArrow, new Vector3(-1.0f, 0.5f, 0.0f), 180f);
        }

        if (action_z == 1)
        {
            SpawnArrow(LArrow, new Vector3(0.0f, 0.5f, 1.0f), 270f);
        }
        else if (action_z == -1)
        {
            SpawnArrow(RPArrow, new Vector3(0.0f, 0.5f, -1.0f), 90f);
        }

        rBody.AddForce(controlSignal * forceMultiplier);
    }

    /// <summary>
    /// Handles the agent's actions: moves the ball, then assigns the step reward and
    /// ends the episode on failure (timeout or fall) or success (target reached).
    /// </summary>
    /// <param name="actionBuffers">Action buffers; discrete branch 0 is X, branch 1 is Z.</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Discrete branches deliver indices, not signed values, so decode them first.
        int inpX = DecodeDirection(actionBuffers.DiscreteActions[0]);
        int inpZ = DecodeDirection(actionBuffers.DiscreteActions[1]);
        MoveBall(inpX, inpZ);

        float nowDistanceToTarget = Vector3.Distance(transform.localPosition, Target.localPosition);

        if (Time.time - startTime >= timeLimit || transform.localPosition.y < 0)
        {
            // Time is up, or the agent fell from the play area.
            SetReward(FAILREWARD);
            EndEpisode();
        }
        else if (nowDistanceToTarget < ReachDistance)
        {
            // Reached the target.
            SetReward(WINREWARD);
            EndEpisode();
        }
        else
        {
            // Reward is measured against the closest distance reached so far:
            // positive when a new minimum is set, negative while the agent is
            // further away than that minimum, zero when exactly equal.
            float thisReward = minDistance - nowDistanceToTarget;
            if (nowDistanceToTarget < minDistance)
            {
                minDistance = nowDistanceToTarget;
                MinLocation.localPosition = transform.localPosition;
            }
            SetReward(thisReward);
        }
    }

    /// <summary>
    /// Keyboard control for debugging: W/S drive the X branch, A/D drive the Z branch.
    /// </summary>
    /// <param name="actionsOut">Action buffers to fill with the chosen discrete actions.</param>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;
        discreteActions[0] = ReadAxis(KeyCode.W, KeyCode.S);
        discreteActions[1] = ReadAxis(KeyCode.A, KeyCode.D);
    }

    // Returns a random target position on the floor plane of the arena.
    private static Vector3 RandomTargetPosition()
    {
        float x = Random.value * (2f * HalfArenaSize) - HalfArenaSize;
        float z = Random.value * (2f * HalfArenaSize) - HalfArenaSize;
        return new Vector3(x, TargetHeight, z);
    }

    // Converts a discrete branch index into a signed direction (-1, 0 or 1).
    // A raw -1 is still accepted so callers using the old signed encoding keep working.
    private static int DecodeDirection(int branchValue)
    {
        switch (branchValue)
        {
            case ActionPositive:
                return 1;
            case ActionNegative:
            case -1:
                return -1;
            default:
                return 0;
        }
    }

    // Maps a pair of keys to a discrete branch index; the positive key wins if both are held.
    private static int ReadAxis(KeyCode positiveKey, KeyCode negativeKey)
    {
        if (Input.GetKey(positiveKey))
        {
            return ActionPositive;
        }
        if (Input.GetKey(negativeKey))
        {
            return ActionNegative;
        }
        return ActionNone;
    }

    // Destroys the arrows this agent spawned on its previous move.
    private void ClearArrows()
    {
        foreach (GameObject arrow in spawnedArrows)
        {
            // An arrow may already have been destroyed externally (e.g. scene unload).
            if (arrow != null)
            {
                Destroy(arrow);
            }
        }
        spawnedArrows.Clear();
    }

    // Spawns one arrow at the given offset from the agent, rotated by yawDegrees around Y.
    private void SpawnArrow(GameObject prefab, Vector3 offset, float yawDegrees)
    {
        // Arrows are purely visual; skip them when the prefab is not assigned.
        if (prefab == null)
        {
            return;
        }

        // Instantiate expects a world-space position, so use transform.position
        // rather than localPosition (they differ when the parent is not at the origin).
        GameObject arrow = Instantiate(prefab, transform.position + offset, Quaternion.Euler(0, yawDegrees, 0));
        spawnedArrows.Add(arrow);
    }
}
+11
View File
@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: c76cbcf0bd8edb346b372986bec61e97
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant: