V2.6 change Ray Tag result as Onehot

change Ray Tag result as Onehot, 
Observation State:
-targetStates 6
-inTargetArea 1
-remainTime 1
-gunReady 1
-my Obs 4
-tag onehot 19*2
-tag dis 19
This commit is contained in:
2022-12-10 10:14:44 +09:00
parent 9b2ba7fb46
commit 64ada808de
12 changed files with 211 additions and 144 deletions
+13 -3
View File
@@ -42,6 +42,7 @@ public class AgentWithGun : Agent
public float yRotation = 0.1f;//定义一个浮点类型的量,记录‘围绕’X轴旋转的角度
[Header("Env")]
public bool oneHotRayTag = true;
private List<float> spinRecord = new List<float>();
private bool lockMouse;
private float Damage;
@@ -418,6 +419,7 @@ public class AgentWithGun : Agent
{
EnvUICon.initChart();
}
rayScript.updateRayInfo(); // update raycast
}
// ML-AGENTS处理-------------------------------------------------------------------------------------------ML-AGENTS
// 观察情报
@@ -426,11 +428,13 @@ public class AgentWithGun : Agent
//List<float> enemyLDisList = RaySensors.enemyLDisList;// All Enemy Lside Distances
//List<float> enemyRDisList = RaySensors.enemyRDisList;// All Enemy Rside Distances
float[] myObserve = { transform.localPosition.x, transform.localPosition.y, transform.localPosition.z, transform.eulerAngles.y };
float[] myObserve = { transform.localPosition.x, transform.localPosition.y, transform.localPosition.z, transform.eulerAngles.y/360f };
float[] rayTagResult = rayScript.rayTagResult;// 探测用RayTag结果 float[](raySensorNum,1)
float[] rayTagResultOnehot = rayScript.rayTagResultOneHot.ToArray(); // 探测用RayTagonehot结果 List<int>[](raySensorNum*Tags,1)
float[] rayDisResult = rayScript.rayDisResult; // 探测用RayDis结果 float[](raySensorNum,1)
float[] targetStates = targetCon.targetState; // (6) targettype, target x,y,z, firebasesAreaDiameter
float remainTime = targetCon.leftTime;
gunReadyToggle = gunReady();
//float[] focusEnemyObserve = RaySensors.focusEnemyInfo;// 最近的Enemy情报 float[](3,1) MinEnemyIndex,x,z
//sensor.AddObservation(allEnemyNum); // 敌人数量 int
@@ -439,11 +443,17 @@ public class AgentWithGun : Agent
sensor.AddObservation(remainTime); // (1)
sensor.AddObservation(gunReadyToggle); //(1) save gun is ready?
sensor.AddObservation(myObserve); // (4)自机位置xyz+朝向 float[](4,1)
sensor.AddObservation(rayTagResult); // 探测用RayTag结果 float[](raySensorNum,1)
if (oneHotRayTag)
{
sensor.AddObservation(rayTagResultOnehot); // 探测用RayTag结果 float[](raySensorNum,1)
}
else
{
sensor.AddObservation(rayTagResult);
}
sensor.AddObservation(rayDisResult); // 探测用RayDis结果 float[](raySensorNum,1)
//sensor.AddObservation(focusEnemyObserve); // 最近的Enemy情报 float[](3,1) MinEnemyIndex,x,z
//sensor.AddObservation(raySensorNum); // raySensor数量 int
gunReadyToggle = gunReady();
//sensor.AddObservation(remainTime); // RemainTime int
}
@@ -0,0 +1,34 @@
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.SideChannels;
public class AimBotSideChennelController : MonoBehaviour
{
public AimbotSideChannel aimbotSideChannel;
public void Awake()
{
// We create the Side Channel
aimbotSideChannel = new AimbotSideChannel();
// When a Debug.Log message is created, we send it to the stringChannel
Application.logMessageReceived += aimbotSideChannel.SendDebugStatementToPython;
// The channel must be registered with the SideChannelManager class
SideChannelManager.RegisterSideChannel(aimbotSideChannel);
}
// Side Channel
public void OnDestroy()
{
// De-register the Debug.Log callback
Application.logMessageReceived -= aimbotSideChannel.SendDebugStatementToPython;
if (Academy.IsInitialized)
{
SideChannelManager.UnregisterSideChannel(aimbotSideChannel);
}
}
}
@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: ea781484763623c438c1806e3a965667
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
+60
View File
@@ -0,0 +1,60 @@
using System;
using System.Collections;
using System.Collections.Generic;
using TMPro;
using UnityEngine;
using UnityEngine.UI;
public class HUDController : MonoBehaviour
{
public bool chartOn = false;
public Toggle chartOnToggleObj;
public TMP_InputField chartOnTimeOutInputObj;
public TMP_InputField enemyNumInputObj;
public float chartOnTimeOut = 1;
public int enemyNum= 3;
public float chartOnTimeOutDefault = 120f;
private float chatOntimeStart = 0;
private void Update()
{
if (chartOn)
{
if (Time.time - chatOntimeStart >= chartOnTimeOut )
{
chartOn = false;
chartOnToggleObj.isOn = false;
}
}
}
public void onChartOnToggleChange()
{
chatOntimeStart = Time.time;
chartOn = chartOnToggleObj.isOn;
}
public void onEnemyNumTextChange()
{
try
{
enemyNum = Math.Abs(int.Parse(enemyNumInputObj.GetComponent<TMP_InputField>().text));
}
catch (NullReferenceException)
{
enemyNum = 3;
}
}
public void onChartTimeOutTextChange()
{
try
{
chartOnTimeOut = Math.Abs(int.Parse(chartOnTimeOutInputObj.GetComponent<TMP_InputField>().text));
}
catch (NullReferenceException)
{
chartOnTimeOut = chartOnTimeOutDefault;
}
}
}
@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 10f34b6fb217eff4e9be2bfe7044f132
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
+50
View File
@@ -0,0 +1,50 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;
using UnityEngine;
public class Onehot
{
private List<string> tags = new List<string>();
public List<List<float>> onehot = new List<List<float>>();
private float totalNum;
public void initialize(List<string> inputTags)
{
tags = inputTags;
totalNum = tags.Count;
for (int i = 0; i < totalNum; i++)
{
List<float> thisOnehot = new List<float>();
for (int j = 0; j < totalNum; j++) thisOnehot.Add(0f);
thisOnehot[i] = 1f;
onehot.Add(thisOnehot);
}
}
public List<float> encoder(string name = null)
{
if (name == null)
{
List<float> allZeroOnehot = new List<float>();
for (int j = 0; j < totalNum; j++) allZeroOnehot.Add(0);
return allZeroOnehot;
}
else
{
try
{
return onehot[tags.IndexOf(name)];
}catch(ArgumentOutOfRangeException)
{
List<float> allZeroOnehot = new List<float>();
for (int j = 0; j < totalNum; j++) allZeroOnehot.Add(0);
return allZeroOnehot;
}
}
}
public string decoder(List<float> thisOnehot)
{
return tags[onehot.IndexOf(thisOnehot)];
}
}
+11
View File
@@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: cdb75d03525930d4caaefab4eaaf6e8a
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
+11 -2
View File
@@ -1,6 +1,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
/*该scrip用于创建复数条ray于视角内,并探测被ray射到的物体*/
@@ -28,6 +29,7 @@ public class RaySensors : MonoBehaviour
[Header("RayCastResult")]
public float[] rayTagResult;
public List<float> rayTagResultOneHot;
public float[] rayDisResult;
[System.NonSerialized] public int totalRayNum;
@@ -37,6 +39,8 @@ public class RaySensors : MonoBehaviour
LineRenderer[] lineRenderers;
rayInfoUI[] rayInfoUIs;
public List<GameObject> inViewEnemies = new List<GameObject>();
private List<string> tags = new List<string> {"Wall","Enemy"};
private Onehot oneHotTags = new Onehot();
private void Start()
@@ -49,7 +53,8 @@ public class RaySensors : MonoBehaviour
lineRenderers = new LineRenderer[totalRayNum];
rayInfoOBJ = new GameObject[totalRayNum];
rayInfoUIs = new rayInfoUI[totalRayNum];
for(int i = 0; i < totalRayNum; i++)
oneHotTags.initialize(tags);
for (int i = 0; i < totalRayNum; i++)
{
linesOBJ[i] = new GameObject();
linesOBJ[i].name = "rayCastLine-" + Convert.ToString(i);
@@ -91,10 +96,12 @@ public class RaySensors : MonoBehaviour
if (Physics.Raycast(ray, out thisHit, viewDistance)) // 若在viewDistance范围内有碰撞
{
rayInfoText = thisHit.collider.tag;
rayTagResult = tagToInt(thisHit.collider.tag);
rayTagResult = tagToInt(rayInfoText);
rayTagResultOneHot.AddRange(oneHotTags.encoder(rayInfoText));
rayDisResult = thisHit.distance;
lineLength = rayDisResult;
rayInfoText += "\n" + Convert.ToString(rayDisResult);
rayDisResult = rayDisResult / viewDistance; // Normalization!
//输出log
switch (rayTagResult)
{
@@ -115,6 +122,7 @@ public class RaySensors : MonoBehaviour
}
else // 若在viewDistance范围无碰撞
{
rayTagResultOneHot.AddRange(oneHotTags.encoder());
rayTagResult = -1f;
rayDisResult = -1f;
//输出log
@@ -165,6 +173,7 @@ public class RaySensors : MonoBehaviour
float focusREdge = agentCam.pixelWidth * (1 + focusRange) / 2;
float thisCamPixelHeight = agentCam.pixelHeight;
inViewEnemies.Clear();
rayTagResultOneHot.Clear();
for (int i = 0; i < halfOuterRayNum; i++) // create left outside rays; 0 ~ focusLeftEdge
{
Vector3 point = new Vector3(i * focusLEdge / (halfOuterRayNum - 1), thisCamPixelHeight / 2, 0);