import gym
import numpy as np
import uuid
import airecorder
from numpy import ndarray
from mlagents_envs.base_env import ActionTuple
from mlagents_envs.environment import UnityEnvironment
from typing import List, Optional, Tuple
from mlagents_envs.side_channel.side_channel import (
    SideChannel,
    IncomingMessage,
    OutgoingMessage,
)
from arguments import set_save_model


class Aimbot(gym.Env):
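    """Gym-style wrapper around a Unity ML-Agents environment.

    reset() and step() return per-agent observations, rewards and done flags;
    close() shuts down the underlying UnityEnvironment.
    """
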
    def __init__(
            self,
            env_path: str,
            worker_id: int = 1,
            base_port: int = 100,
            side_channels: Optional[list] = None,  # None instead of a mutable default argument
    ):
        super().__init__()
        self.env = UnityEnvironment(
            file_name=env_path,
            seed=1,
            side_channels=side_channels if side_channels is not None else [],
            worker_id=worker_id,
            base_port=base_port,
        )
        self.env.reset()
        # all behavior_specs
        self.unity_specs = self.env.behavior_specs
        #  environment behavior name
        self.unity_beha_name = list(self.unity_specs)[0]
        #  behavior spec of that behavior (reuses the attribute above)
        self.unity_specs = self.unity_specs[self.unity_beha_name]
        #  environment observation spec (first entry of observation_specs)
        self.unity_obs_specs = self.unity_specs.observation_specs[0]
        #  environment action specs
        self.unity_action_spec = self.unity_specs.action_spec
        #  initial decision steps, used below to read the agent IDs
        decision_steps, _ = self.env.get_steps(self.unity_beha_name)

        # OBSERVATION SPECS
        #  environment observation shape, e.g. (93,)
        self.unity_observation_shape = self.unity_obs_specs.shape

        # ACTION SPECS
        #  number of continuous actions (int)
        self.unity_continuous_size = self.unity_action_spec.continuous_size
        #  discrete action branch sizes, e.g. tuple (3, 3, 2)
        self.unity_discrete_branches = self.unity_action_spec.discrete_branches
        #  number of discrete action branches, e.g. 3
        self.unity_discrete_type = self.unity_action_spec.discrete_size
        # total number of discrete choices across all branches, e.g. 3+3+2=8
        self.unity_discrete_size = sum(self.unity_discrete_branches)
        # total action vector size (discrete branches + continuous), e.g. 3+2=5
        self.unity_action_size = self.unity_discrete_type + self.unity_continuous_size
        # whether discrete / continuous actions exist
        self.unity_dis_act_exist = self.unity_discrete_type != 0
        self.unity_con_act_exist = self.unity_continuous_size != 0

        # AGENT SPECS
        # IDs of all agents
        self.unity_agent_IDS = decision_steps.agent_id
        # number of agents
        self.unity_agent_num = len(self.unity_agent_IDS)

        # all-zero action array, one row per agent
        self.all_zero_action = np.zeros((self.unity_agent_num, self.unity_action_size))

    def reset(self) -> Tuple[np.ndarray, List, List]:
        """reset environment and get observations

        Returns:
            ndarray: next_state, reward, done, loadDir, saveNow
        """
        # reset env
        self.env.reset()
        next_state, reward, done = self.get_steps()
        return next_state, reward, done

    def step(
            self,
            actions: ndarray,
    ) -> Tuple[np.ndarray, List, List]:
        """change actions list to ActionTuple then send it to environment

        Args:
            actions (ndarray): PPO chooseAction output action list.(agentNum,actionNum)

        Returns:
            ndarray: nextState, reward, done
        """
        # take action to environment
        # return mextState,reward,done
        # discrete action
        if self.unity_dis_act_exist:
            # slice the discrete actions out of the actions array
            discrete_actions = actions[:, 0: self.unity_discrete_type]
        else:
            # empty discrete action with the shape ML-Agents expects: (agent_num, 0)
            discrete_actions = np.zeros((self.unity_agent_num, 0), dtype=np.int32)
        # continuous action
        if self.unity_con_act_exist:
            # slice the continuous actions out of the actions array
            continuous_actions = actions[:, self.unity_discrete_type:]
        else:
            # empty continuous action with the shape ML-Agents expects: (agent_num, 0)
            continuous_actions = np.zeros((self.unity_agent_num, 0), dtype=np.float32)

        # create actionTuple
        this_action_tuple = ActionTuple(continuous=continuous_actions, discrete=discrete_actions)
        # take action to env
        self.env.set_actions(behavior_name=self.unity_beha_name, action=this_action_tuple)
        self.env.step()
        # get nextState & reward & done after this action
        next_states, rewards, dones = self.get_steps()
        return next_states, rewards, dones

    def get_steps(self) -> Tuple[np.ndarray, List, List]:
        """get environment now observations.
        Include State, Reward, Done

        Args:

        Returns:
            ndarray: nextState, reward, done
        """
        # get nextState & reward & done
        decision_steps, terminal_steps = self.env.get_steps(self.unity_beha_name)
        next_states = []
        dones = []
        rewards = []
        for this_agent_ID in self.unity_agent_IDS:
            # When an episode ends, the agent ID can appear in both decision_steps
            # and terminal_steps. The agent_exist flag prevents recording a
            # redundant state and reward for the same agent.
            agent_exist = False
            # episode is over for this agent
            if this_agent_ID in terminal_steps:
                next_states.append(terminal_steps[this_agent_ID].obs[0])
                dones.append(True)
                rewards.append(terminal_steps[this_agent_ID].reward)
                agent_exist = True
            # episode still running and this agent was not already recorded above
            if (this_agent_ID in decision_steps) and (not agent_exist):
                next_states.append(decision_steps[this_agent_ID].obs[0])
                dones.append(False)
                rewards.append(decision_steps[this_agent_ID].reward)

        return np.asarray(next_states), rewards, dones

    def close(self):
        self.env.close()
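
# Minimal usage sketch (illustrative only): drives the environment with all-zero
# actions. The executable path "./AimbotEnv" and the port are placeholders, not
# values defined by this module.
#
#   env = Aimbot(env_path="./AimbotEnv", worker_id=1, base_port=5005)
#   states, rewards, dones = env.reset()
#   while not all(dones):
#       actions = np.zeros((env.unity_agent_num, env.unity_action_size))
#       states, rewards, dones = env.step(actions)
#   env.close()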


class AimbotSideChannel(SideChannel):
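    """Side channel that receives log/command messages from Unity and can
    queue primitive values to send back to the C# side."""
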
    def __init__(self, channel_id: uuid.UUID) -> None:
        super().__init__(channel_id)

    def on_message_received(self, msg: IncomingMessage) -> None:
        """
        Note: We must implement this method of the SideChannel interface to
        receive messages from Unity
        Message will be sent like this:
        "Warning|Message1|Message2|Message3" or
        "Error|Message1|Message2|Message3"
        """
        this_message_Original = msg.read_string()
        this_message = this_message_Original.split("|")
        print(this_message)
        if this_message[0] == "Warning":
            if this_message[1] == "Result":
                airecorder.total_rounds[this_message[2]] += 1
                if this_message[3] == "Win":
                    airecorder.win_rounds[this_message[2]] += 1
            if this_message[1] == "Command":
                set_save_model(True)
                print("Command: " + this_message_Original)
        elif this_message[0] == "Error":
            print(this_message_original)

    # senders: queue primitive values for delivery to the C# side
    def send_string(self, data: str) -> None:
        # send a string to C#
        msg = OutgoingMessage()
        msg.write_string(data)
        super().queue_message_to_send(msg)

    def send_bool(self, data: bool) -> None:
        msg = OutgoingMessage()
        msg.write_bool(data)
        super().queue_message_to_send(msg)

    def send_int(self, data: int) -> None:
        msg = OutgoingMessage()
        msg.write_int32(data)
        super().queue_message_to_send(msg)

    def send_float(self, data: float) -> None:
        msg = OutgoingMessage()
        msg.write_float32(data)
        super().queue_message_to_send(msg)

    def send_float_list(self, data: List[float]) -> None:
        msg = OutgoingMessage()
        msg.write_float32_list(data)
        super().queue_message_to_send(msg)
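
# Example wiring (illustrative only): the channel UUID must match the GUID
# registered on the Unity/C# side; uuid.uuid4() below is only a stand-in.
#
#   side_channel = AimbotSideChannel(uuid.uuid4())
#   env = Aimbot(env_path="./AimbotEnv", side_channels=[side_channel])
#   side_channel.send_string("ping")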