using System;
using System.Collections.Generic;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using UnityEngine;

/// <summary>
/// ML-Agents agent for this scene. It collects observations from the agent, target and
/// ray-sensor controllers, applies three discrete actions (vertical move, horizontal move,
/// shoot) plus one continuous action (mouse X), and reports episode results through the
/// side channel and the HUD message box.
/// </summary>
public class MLAgentsCustomController : Agent
{
    // Discrete action index that encodes the negative direction of a movement axis.
    // ML-Agents discrete actions are non-negative branch indices, so -1 cannot be sent
    // directly. Encoding: 0 = no input, 1 = positive direction, 2 = negative direction.
    private const int NegativeAxisAction = 2;

    [SerializeField] private GameObject paramContainerObj;
    [SerializeField] private GameObject targetControllerObj;
    [SerializeField] private GameObject environmentUIObj;
    [SerializeField] private GameObject sideChannelObj;
    [SerializeField] private GameObject worldUIControllerObj;
    [SerializeField] private GameObject hudObj;

    // component references (resolved in Start)
    private AgentController agentController;

    private ParameterContainer paramContainer;
    private CommonParameterContainer commonParamCon;
    private TargetController targetController;
    private EnvironmentUIControl envUIController;
    private HUDController hudController;
    private TargetUIController targetUIController;
    private RaySensors raySensors;
    private MessageBoxController messageBoxController;
    private AimBotSideChannelController sideChannelController;
    private WorldUIController worldUICon;
    private RewardFunction rewardFunction;

    // observation buffers
    private float[] myObserve = new float[5];

    private float[] rayTagResult;
    private float[] rayTagResultOnehot;
    private float[] rayDisResult;
    private float remainTime;
    private float inFireBaseState;

    private int endTypeInt;

    // Set once the observation vector size has been logged, so it is not logged every step.
    private bool observationSizeLogged;

    /// <summary>
    /// Resolves component references from this GameObject and from the serialized scene objects.
    /// </summary>
    private void Start()
    {
        agentController = transform.GetComponent<AgentController>();
        raySensors = transform.GetComponent<RaySensors>();
        paramContainer = paramContainerObj.GetComponent<ParameterContainer>();
        commonParamCon = CommonParameterContainer.Instance;
        targetController = targetControllerObj.GetComponent<TargetController>();
        envUIController = environmentUIObj.GetComponent<EnvironmentUIControl>();
        hudController = hudObj.GetComponent<HUDController>();
        targetUIController = hudObj.GetComponent<TargetUIController>();
        messageBoxController = hudObj.GetComponent<MessageBoxController>();
        sideChannelController = sideChannelObj.GetComponent<AimBotSideChannelController>();
        rewardFunction = gameObject.GetComponent<RewardFunction>();
        worldUICon = worldUIControllerObj.GetComponent<WorldUIController>();
    }

    /// <summary>
    /// Resets per-episode state. When <c>commonParamCon.gameMode == 0</c> (train mode) a new
    /// scene is rolled. Otherwise (play mode) the targets are re-initialized and the target UI
    /// is cleared.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        agentController.UpdateLockMouse();
        paramContainer.ResetTimeBonusReward();
        if (commonParamCon.gameMode == 0)
        {
            // train mode
            Debug.Log("MLAgentCustomController.OnEpisodeBegin: train mode start");
            targetController.RollNewScene();
        }
        else
        {
            // play mode
            Debug.Log("MLAgentCustomController.OnEpisodeBegin: play mode start");
            targetController.PlayModeInitialize();
            // reset target UI
            targetUIController.ClearGamePressed();
        }

        // Re-initialize the reward chart only when the HUD chart is enabled.
        if (hudController.chartOn)
        {
            envUIController.InitChart();
        }
        // Refresh raycasts so the episode's first observation uses current ray data.
        raySensors.UpdateRayInfo();
    }

    /// <summary>
    /// Adds observations to <paramref name="sensor"/> in this fixed order:
    /// <list type="number">
    /// <item>target state</item>
    /// <item>in-fire-base-area state (1)</item>
    /// <item>remaining time (1)</item>
    /// <item>gun-ready flag (1)</item>
    /// <item>agent pose (5)</item>
    /// <item>ray tags, one-hot or raw depending on <c>commonParamCon.oneHotRayTag</c></item>
    /// <item>ray distances</item>
    /// </list>
    /// The order must stay stable because it defines the policy's input layout.
    /// </summary>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Agent pose: local position x, y, z plus yaw as (sin, cos). The sin/cos pair keeps
        // the heading continuous across the 0/360 degree wrap-around.
        float yawRadians = transform.eulerAngles.y * Mathf.Deg2Rad;
        myObserve[0] = transform.localPosition.x;
        myObserve[1] = transform.localPosition.y;
        myObserve[2] = transform.localPosition.z;
        myObserve[3] = MathF.Sin(yawRadians);
        myObserve[4] = MathF.Cos(yawRadians);

        // Ray-sensor results from the most recent raySensors.UpdateRayInfo() call.
        rayTagResult = raySensors.rayTagResult;             // tag value per ray
        rayTagResultOnehot = raySensors.rayTagResultOneHot; // one-hot tag encoding per ray
        rayDisResult = raySensors.rayDisResult;             // hit distance per ray
        remainTime = targetController.leftTime;
        inFireBaseState = targetController.GetInAreaState();
        agentController.UpdateGunState();

        sensor.AddObservation(targetController.targetState);    // target type, target x/y/z, fire-base area diameter (per original note)
        sensor.AddObservation(inFireBaseState);                 // 1
        sensor.AddObservation(remainTime);                      // 1
        sensor.AddObservation(agentController.gunReadyToggle);  // 1: whether the gun is ready
        sensor.AddObservation(myObserve);                       // 5: position xyz + yaw sin/cos
        int observationCount = targetController.targetState.Length + 1 + 1 + 1 + myObserve.Length;

        float[] rayTagObservation = commonParamCon.oneHotRayTag ? rayTagResultOnehot : rayTagResult;
        sensor.AddObservation(rayTagObservation);
        observationCount += rayTagObservation.Length;

        sensor.AddObservation(rayDisResult);
        observationCount += rayDisResult.Length;

        envUIController.UpdateStateText(targetController.targetState, inFireBaseState, remainTime, agentController.gunReadyToggle, myObserve, rayTagResultOnehot, rayDisResult);

        // Log the observation vector size once rather than on every decision step.
        if (!observationSizeLogged)
        {
            Debug.Log("MLAgentsCustomController.CollectObservations: observation size = " + observationCount);
            observationSizeLogged = true;
        }
    }

    /// <summary>
    /// Applies the chosen actions, then evaluates termination and the step reward.
    /// Discrete branches are:
    /// <list type="number">
    /// <item>vertical move (0/1/2)</item>
    /// <item>horizontal move (0/1/2)</item>
    /// <item>shoot</item>
    /// </list>
    /// Continuous action 0 is mouse X. On a Win or Lose result, the result is reported and the
    /// episode ends.
    /// </summary>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Read the actions.
        int vertical = DecodeAxis(actionBuffers.DiscreteActions[0]);
        int horizontal = DecodeAxis(actionBuffers.DiscreteActions[1]);
        int mouseShoot = actionBuffers.DiscreteActions[2];
        float mouseX = actionBuffers.ContinuousActions[0];

        // Apply the actions.
        agentController.CameraControl(mouseX, 0);
        agentController.MoveAgent(vertical, horizontal);
        raySensors.UpdateRayInfo(); // update raycast

        // Check whether the episode is over and compute this step's reward.
        float sceneReward;
        float endReward;
        (endTypeInt, sceneReward, endReward) = rewardFunction.CheckOverAndRewards();
        float nowReward = rewardFunction.RewardCalculate(sceneReward + endReward, mouseX, Math.Abs(vertical) + Math.Abs(horizontal), mouseShoot);
        if (hudController.chartOn)
        {
            envUIController.UpdateChart(nowReward);
        }
        else
        {
            envUIController.RemoveChart();
        }
        worldUICon.UpdateChart(targetController.targetType, endTypeInt);

        // Set the reward exactly once and before any EndEpisode(). EndEpisode() resets the
        // agent's reward and begins the next episode, so setting the reward again afterwards
        // would carry this terminal reward into the next episode's first step.
        SetReward(nowReward);

        if (endTypeInt == (int)TargetController.EndType.Running)
        {
            return; // game not over yet
        }

        // Win or lose: report the result, then finish the episode.
        Debug.Log("Finish reward = " + nowReward);
        string targetString = Enum.GetName(typeof(Targets), targetController.targetType);
        switch (endTypeInt)
        {
            case (int)TargetController.EndType.Win:
                sideChannelController.SendSideChannelMessage("Result", targetString + "|Win");
                messageBoxController.PushMessage(
                    new List<string> { "Game Win" },
                    new List<string> { "green" });
                break;

            case (int)TargetController.EndType.Lose:
                sideChannelController.SendSideChannelMessage("Result", targetString + "|Lose");
                messageBoxController.PushMessage(
                    new List<string> { "Game Lose" },
                    new List<string> { "red" });
                break;

            default:
                Debug.LogWarning("TypeError");
                break;
        }
        EndEpisode();
    }

    /// <summary>
    /// Keyboard and mouse control for heuristic mode.
    /// <list type="bullet">
    /// <item>W/S drive discrete branch 0 and D/A drive branch 1, using the same 0/1/2 encoding
    /// that <see cref="OnActionReceived"/> decodes.</item>
    /// <item>The left mouse button drives branch 2 (shoot).</item>
    /// <item>Mouse X, scaled by sensitivity and frame time, drives continuous action 0.</item>
    /// </list>
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        ActionSegment<float> continuousActions = actionsOut.ContinuousActions;
        ActionSegment<int> discreteActions = actionsOut.DiscreteActions;

        // Discrete control.
        discreteActions[0] = EncodeAxis(Input.GetKey(KeyCode.W), Input.GetKey(KeyCode.S));
        discreteActions[1] = EncodeAxis(Input.GetKey(KeyCode.D), Input.GetKey(KeyCode.A));
        discreteActions[2] = Input.GetMouseButton(0) ? 1 : 0;

        // Continuous control.
        continuousActions[0] = Input.GetAxis("Mouse X") * agentController.mouseXSensitivity * Time.deltaTime;
    }

    /// <summary>Maps a discrete axis action (0, 1, 2) to a signed direction (0, 1, -1).</summary>
    private static int DecodeAxis(int action)
    {
        return action == NegativeAxisAction ? -1 : action;
    }

    /// <summary>
    /// Encodes a pair of opposing inputs as a discrete axis action: 1 when only the positive
    /// input is active, <see cref="NegativeAxisAction"/> when only the negative one is, and 0
    /// when both or neither are active.
    /// </summary>
    private static int EncodeAxis(bool positivePressed, bool negativePressed)
    {
        if (positivePressed == negativePressed)
        {
            return 0;
        }
        return positivePressed ? 1 : NegativeAxisAction;
    }
}