from os import mkdir
import tensorflow as tf
from tensorflow.python.ops.numpy_ops import ndarray
import tensorflow_probability as tfp
import numpy as np
import time
import math

import datetime
from PPOConfig import PPOConfig

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import optimizers

EPS = 1e-10


class PPO(object):
    def __init__(
        self,
        stateSize: int,
        disActShape: list,
        conActSize: int,
        conActRange: float,
        PPOConfig: PPOConfig,
    ):
        """initialize PPO

        Args:
            stateSize (int): enviroment state size
            disActShape (numpy): discrete Action shape.
                                just like [3,2],means 2 type of dis actions,each act include 3 and 2 types
                                if no discrete action output then use [0].
            conActSize (int): continuous Action Size. if no continuous action output then use 0.
            conActRange (float): continuous action range. -conActRange to +conActRange
            PPOConfig (PPOConfig): PPO configuration
        """
        # check whether discrete actions are used.
        if disActShape == [0]:
            # no discrete action output
            self.disActSize = 0
            self.disOutputSize = 0
        else:
            # every discrete branch must have at least 2 choices
            if np.any(np.array(disActShape) <= 1):
                raise ValueError(
                    "disActShape error: every entry of disActShape must be greater than 1, got "
                    + str(disActShape)
                )
            self.disActSize = len(disActShape)
            self.disOutputSize = sum(disActShape)

        self.stateSize = stateSize
        self.disActShape = disActShape
        self.conActSize = conActSize
        self.conActRange = conActRange
        self.muSigSize = 2
        self.conOutputSize = conActSize * self.muSigSize

        # config
        self.NNShape = PPOConfig.NNShape
        self.criticLR = PPOConfig.criticLR
        self.actorLR = PPOConfig.actorLR
        self.gamma = PPOConfig.gamma
        self.lmbda = PPOConfig.lmbda
        self.clipRange = PPOConfig.clipRange
        self.entropyWeight = PPOConfig.entropyWeight
        self.trainEpochs = PPOConfig.trainEpochs
        self.saveDir = PPOConfig.saveDir
        self.loadModelDir = PPOConfig.loadModelDir
        print("---------thisPPO Params---------")
        print("self.stateSize = ", self.stateSize)
        print("self.disActShape = ", self.disActShape)
        print("self.disActSize", self.disActSize)
        print("self.disOutputSize", self.disOutputSize)
        print("self.conActSize = ", self.conActSize)
        print("self.conActRange = ", self.conActRange)
        print("self.conOutputSize = ", self.conOutputSize)

        # config
        print("---------thisPPO config---------")
        print("self.NNShape = ", self.NNShape)
        print("self.criticLR = ", self.criticLR)
        print("self.actorLR = ", self.actorLR)
        print("self.gamma = ", self.gamma)
        print("self.lmbda = ", self.lmbda)
        print("self.clipRange = ", self.clipRange)
        print("self.entropyWeight = ", self.entropyWeight)
        print("self.trainEpochs = ", self.trainEpochs)
        print("self.saveDir = ", self.saveDir)
        print("self.loadModelDir = ", self.loadModelDir)

        # load NN or not
        if self.loadModelDir is None:
            # critic NN
            self.critic = self.buildCriticNet(self.stateSize, 1, compileModel=True)
            # actor NN
            self.actor = self.buildActorNet(self.stateSize, compileModel=True)
            print("---------Actor Model Create Success---------")
            self.actor.summary()
            print("---------Critic Model Create Success---------")
            self.critic.summary()
        else:
            # critic NN
            self.critic = self.buildCriticNet(self.stateSize, 1, compileModel=True)
            # actor NN
            self.actor = self.buildActorNet(self.stateSize, compileModel=True)
            # load weight to Critic&Actor NN
            self.loadWeightToModels(self.loadModelDir)
            print("---------Actor Model Load Success---------")
            self.actor.summary()
            print("---------Critic Model Load Success---------")
            self.critic.summary()

    # Build Net
    def buildActorNet(self, inputSize: int, compileModel: bool):
        """build Actor Nueral Net and compile.Output:[disAct1,disAct2,disAct3,mu,sigma]

        Args:
            inputSize (int): InputLayer Nueral size.
            compileModel (bool): compile Model or not.

        Returns:
            keras.Model: return Actor NN
        """
        # -----------Input Layers-----------
        stateInput = layers.Input(shape=(inputSize,), name="stateInput")

        # -------Intermediate layers--------
        interLayers = []
        interLayersIndex = 0
        for neuralUnit in self.NNShape:
            thisLayerName = "dense" + str(interLayersIndex)
            if interLayersIndex == 0:
                interLayers.append(
                    layers.Dense(neuralUnit, activation="relu", name=thisLayerName)(stateInput)
                )
            else:
                interLayers.append(
                    layers.Dense(neuralUnit, activation="relu", name=thisLayerName)(interLayers[-1])
                )
            interLayersIndex += 1

        # ----------Output Layers-----------
        outputLayersList = []
        if self.disActSize != 0:
            # NN has discrete action outputs: one softmax head per discrete branch.
            disActIndex = 0
            for thisDisActDepth in self.disActShape:
                thisDisActName = "disAct" + str(disActIndex)
                outputLayersList.append(
                    layers.Dense(thisDisActDepth, activation="softmax", name=thisDisActName)(
                        interLayers[-1]
                    )
                )
                disActIndex += 1
        if self.conActSize != 0:
            # NN has continuous action outputs: mu and sigma heads.
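            # NOTE: a single (mu, sigma) pair is produced below, so this head effectively
            # assumes conActSize == 1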
            mu = tf.multiply(
                layers.Dense(1, activation="tanh", name="muOut")(interLayers[-1]), self.conActRange
            )  # mu: location parameter of the normal distribution
            sigma = tf.add(
                layers.Dense(1, activation="softplus", name="sigmaOut")(interLayers[-1]), EPS
            )  # sigma: scale parameter of the normal distribution (softplus keeps it positive)
            outputLayersList.append(mu)
            outputLayersList.append(sigma)
        totalOut = layers.concatenate(outputLayersList, name="totalOut")  # concatenate all heads

        # ----------Model Compile-----------
        model = keras.Model(inputs=stateInput, outputs=totalOut)
        if compileModel:  # Compile Model
            actorOPT = optimizers.Adam(learning_rate=self.actorLR)
            model.compile(optimizer=actorOPT, loss=self.aLoss())
        return model

    def buildCriticNet(self, inputSize: int, outputSize: int, compileModel: bool):
        """build Critic Nueral Net and compile.Output:[Q]

        Args:
            inputSize (int): input size
            outputSize (int): output size
            compileModel (bool): compile Model or not.

        Returns:
            keras.Model: return Critic NN
        """
        # -----------Input Layers-----------
        stateInput = keras.Input(shape=(inputSize,), name="stateInput")

        # -------Intermediate layers--------
        interLayers = []
        interLayersIndex = 0
        for neuralUnit in self.NNShape:
            thisLayerName = "dense" + str(interLayersIndex)
            if interLayersIndex == 0:
                interLayers.append(
                    layers.Dense(neuralUnit, activation="relu", name=thisLayerName)(stateInput)
                )
            else:
                interLayers.append(
                    layers.Dense(neuralUnit, activation="relu", name=thisLayerName)(interLayers[-1])
                )
            interLayersIndex += 1

        # ----------Output Layers-----------
        output = layers.Dense(outputSize, activation=None)(interLayers[-1])

        # ----------Model Compile-----------
        model = keras.Model(inputs=stateInput, outputs=output)
        if compileModel:
            criticOPT = optimizers.Adam(learning_rate=self.criticLR)
            model.compile(optimizer=criticOPT, loss=self.cLoss())
        return model

    # loss Function
    # critic loss
    def cLoss(self):
        """Critic Loss function"""

        def loss(y_true, y_pred):
            # y_true: discountedR (discounted return targets)
            # y_pred: criticV, the critic's predicted state values

            adv = y_true - y_pred  # residual: return target minus predicted value
            loss = tf.reduce_mean(tf.square(adv))
            return loss

        return loss

    # actor loss
    def aLoss(self):
        """Actor Loss function"""

        def getDiscreteALoss(nowProbs, oldProbs, disOneHotAct, actShape, advantage):
            """get Discrete Action Loss

            Args:
                nowProbs (tf.constant): (length,actionProbSize)
                oldProbs (tf.constant): (length,actionProbSize)
                advantage (tf.constant): (length,)

            Returns:
                tf.constant: (length,)
            """
            entropy = tf.negative(
                tf.reduce_mean(tf.math.multiply(nowProbs, tf.math.log(nowProbs + EPS)))
            )
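            # the one-hot mask keeps only the taken action's probability; reduce_mean
            # divides by actShape, so multiplying by actShape afterwards recovers the
            # per-sample probability of the chosen action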
            nowSingleProbs = tf.reduce_mean(tf.multiply(nowProbs, disOneHotAct), axis=1)
            nowSingleProbs = tf.multiply(nowSingleProbs, actShape)
            oldSingleProbs = tf.reduce_mean(tf.multiply(oldProbs, disOneHotAct), axis=1)
            oldSingleProbs = tf.multiply(oldSingleProbs, actShape)
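            # PPO clipped surrogate: ratio = pi_new(a|s) / pi_old(a|s), clipped to
            # [1 - clipRange, 1 + clipRange], combined with an entropy term weighted
            # by self.entropyWeight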
            ratio = tf.math.divide(nowSingleProbs, oldSingleProbs + EPS)
            value = tf.math.multiply(ratio, advantage)
            clipRatio = tf.clip_by_value(ratio, 1.0 - self.clipRange, 1.0 + self.clipRange)
            clipValue = tf.math.multiply(clipRatio, advantage)
            loss = tf.math.negative(
                tf.reduce_mean(tf.math.minimum(value, clipValue))
                - tf.multiply(self.entropyWeight, entropy)
            )
            return loss

        def getContinuousALoss(musig, actions, oldProbs, advantage):
            """get Continuous Action Loss

            Args:
                musig (tf.constant): (length,2)
                actions (tf.constant): (length,)
                oldProbs (tf.constant): (length,)
                advantage (tf.constant): (length,)

            Returns:
                tf.constant: scalar loss for this continuous action
            """
            mu = musig[:, 0]
            sigma = musig[:, 1]
            dist = tfp.distributions.Normal(mu, sigma)

            nowProbs = dist.prob(actions)
            entropy = tf.reduce_mean(dist.entropy())

            ratio = tf.math.divide(nowProbs, oldProbs + EPS)
            value = tf.math.multiply(ratio, advantage)
            clipRatio = tf.clip_by_value(ratio, 1.0 - self.clipRange, 1.0 + self.clipRange)
            clipValue = tf.math.multiply(clipRatio, advantage)
            loss = tf.negative(
                tf.reduce_mean(tf.math.minimum(value, clipValue))
                - tf.multiply(self.entropyWeight, entropy)
            )
            return loss

        def loss(y_true, y_pred):
            # y_true: [[disActProb..., conActProbs..., disOneHotActs..., conAct..., advantage]]
            # y_pred: [[disActProb..., mu, sigma...]]
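            # y_true column widths:
            #   disOutputSize | conActSize | disOutputSize | conActSize | 1 (advantage)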
            totalALoss = 0
            totalActionNum = 0
            advantage = y_true[:, -1]

            if self.disActSize != 0:
                # NN has discrete action outputs.
                oldDisProbs = y_true[:, 0 : self.disOutputSize]
                nowDisProbs = y_pred[:, 0 : self.disOutputSize]  # [disAct1, disAct2, disAct3]
                disOneHotActs = y_true[
                    :,
                    self.disOutputSize
                    + self.conActSize : self.disOutputSize
                    + self.conActSize
                    + self.disOutputSize,
                ]
                lastDisActShape = 0
                for thisShape in self.disActShape:
                    thisNowDisProbs = nowDisProbs[:, lastDisActShape : lastDisActShape + thisShape]
                    thisOldDisProbs = oldDisProbs[:, lastDisActShape : lastDisActShape + thisShape]
                    thisDisOneHotActs = disOneHotActs[
                        :, lastDisActShape : lastDisActShape + thisShape
                    ]
                    discreteALoss = getDiscreteALoss(
                        thisNowDisProbs, thisOldDisProbs, thisDisOneHotActs, thisShape, advantage
                    )
                    lastDisActShape += thisShape
                    totalALoss += discreteALoss
                    totalActionNum += 1.0
            if self.conActSize != 0:
                # NN has continuous action outputs.
                oldConProbs = y_true[:, self.disOutputSize : self.disOutputSize + self.conActSize]
                conActions = y_true[
                    :,
                    self.disOutputSize
                    + self.conActSize
                    + self.disOutputSize : self.disOutputSize
                    + self.conActSize
                    + self.disOutputSize
                    + self.conActSize,
                ]
                nowConMusigs = y_pred[:, self.disOutputSize :]  # [musig1,musig2]
                lastConAct = 0
                for conAct in range(self.conActSize):
                    thisNowConMusig = nowConMusigs[:, lastConAct : lastConAct + self.muSigSize]
                    thisOldConProb = tf.squeeze(oldConProbs[:, conAct : conAct + 1])
                    thisConAction = conActions[:, conAct]
                    continuousAloss = getContinuousALoss(
                        thisNowConMusig, thisConAction, thisOldConProb, advantage
                    )
                    totalALoss += continuousAloss
                    totalActionNum += 1.0
                    lastConAct += self.muSigSize
            loss = tf.divide(totalALoss, totalActionNum)
            return loss

        return loss

    # get Actions&values
    def chooseAction(self, state: ndarray):
        """Agent choose action to take

        Args:
            state (ndarray): enviroment state

        Returns:
            np.array:
                actions,
                    actions list,1dims like [0,1,1.5]
                predictResult,
                    actor NN predict Result output
        """
        # let actor choose action,use the normal distribution
        # state = np.expand_dims(state,0)

        # reshape state to 2-D: [stateNum, stateSize]
        if state.ndim != 2:
            stateNum = int(len(state) / self.stateSize)
            state = state.reshape([stateNum, self.stateSize])
        predictResult = self.actor(state)  # get predict result [[disAct1, disAct2, disAct3, musig]]
        # print("predictResult",predictResult)
        # predictResult = predictResult.numpy()
        actions = []
        if self.disActSize != 0:
            # NN has discrete action outputs: sample each branch from its categorical distribution.
            lastDisActShape = 0
            for shape in self.disActShape:
                thisDisActProbs = predictResult[:, lastDisActShape : lastDisActShape + shape]
                dist = tfp.distributions.Categorical(probs=thisDisActProbs, dtype=tf.float32)
                action = int(dist.sample().numpy()[0])
                # action = np.argmax(thisDisActProbs)
                actions.append(action)
                lastDisActShape += shape
        if self.conActSize != 0:
            # NN has continuous action outputs.
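            # sample each continuous action from N(mu, sigma) predicted by the actor,
            # then clip the sample to [-conActRange, +conActRange]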
            lastConAct = 0
            for actIndex in range(self.conActSize):
                thisMu = predictResult[:, self.disOutputSize + lastConAct]
                thisSig = predictResult[:, self.disOutputSize + lastConAct + 1]
                if math.isnan(thisMu) or math.isnan(thisSig):
                    # check mu or sigma is nan
                    print("chooseAction:mu or sigma is nan")
                    print(predictResult)
                sampledAct = np.random.normal(loc=thisMu, scale=thisSig)
                actions.append(np.clip(sampledAct, -self.conActRange, self.conActRange)[0])
                lastConAct += 2
        return actions, predictResult

    def trainCritcActor(
        self,
        states: ndarray,
        oldActorResult: ndarray,
        actions: ndarray,
        rewards: ndarray,
        dones: ndarray,
        nextState: ndarray,
        epochs: int = None,
    ):
        """train critic&actor use PPO ways

        Args:
            states (ndarray): states
            oldActorResult (ndarray): actor predict result
            actions (ndarray): predicted actions include both discrete actions and continuous actions
            rewards (ndarray): rewards from enviroment
            dones (ndarray): dones from enviroment
            nextState (ndarray): next state from enviroment
            epochs (int, optional): train epochs,default to ppoConfig. Defaults to None.

        Returns:
            tf.constant: criticLoss, actorLoss
        """

        if epochs is None:
            epochs = self.trainEpochs
        criticValues = self.getCriticV(state=states)
        discountedR = self.discountReward(nextState, criticValues, dones, rewards)
        advantage = self.getGAE(discountedR, criticValues)
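        # fit the critic toward the discounted-return targets, then update the actor
        # with the clipped-surrogate loss using the normalized GAE advantages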

        criticLoss = self.trainCritic(states, discountedR, epochs)
        actorLoss = self.trainActor(states, oldActorResult, actions, advantage, epochs)
        # print("A_Loss:", actorLoss, "C_Loss:", criticLoss)
        return actorLoss, criticLoss

    def trainCritic(self, states: ndarray, discountedR: ndarray, epochs: int = None):
        """critic NN trainning function

        Args:
            states (ndarray): states
            discountedR (ndarray): discounted rewards
            epochs (int, optional): train epochs,default to ppoConfig. Defaults to None.

        Returns:
            tf.constant: all critic losses
        """
        if epochs is None:
            epochs = self.trainEpochs
        his = self.critic.fit(x=states, y=discountedR, epochs=epochs, verbose=0)
        return his.history["loss"]

    def trainActor(
        self,
        states: ndarray,
        oldActorResult: ndarray,
        actions: ndarray,
        advantage: ndarray,
        epochs: int = None,
    ):
        """actor NN trainning function

        Args:
            states (ndarray): states
            oldActorResult (ndarray): actor predict results
            actions (ndarray): acotor predict actions
            advantage (ndarray): GAE advantage
            epochs (int, optional): train epochs,default to ppoConfig. Defaults to None.

        Returns:
            tf.constant: all actor losses
        """
        # Train Actor
        # states: buffered states
        # oldActorResult: actor outputs recorded when the actions were taken
        # actions: buffered actions
        # advantage: GAE advantages
        if epochs is None:
            epochs = self.trainEpochs
        actions = np.asarray(actions, dtype=np.float32)
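        # Keras losses only receive (y_true, y_pred), so the old action probabilities,
        # one-hot discrete actions, raw continuous actions and the advantage are all
        # packed into y_true below and unpacked again inside aLoss()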

        disActions = actions[:, 0 : self.disActSize]
        conActions = actions[:, self.disActSize :]
        oldDisProbs = oldActorResult[:, 0 : self.disOutputSize]  # [disAct1, disAct2, disAct3]
        oldConMusigs = oldActorResult[:, self.disOutputSize :]  # [musig1,musig2]
        if self.disActSize != 0:
            disOneHotActs = self.getOneHotActs(disActions)
            if self.conActSize != 0:
                # NN has both discrete and continuous action outputs.
                oldPiProbs = self.conProb(oldConMusigs[:, 0], oldConMusigs[:, 1], conActions)
                # pack [oldDisProbs, oldPiProbs, disOneHotActs, conActions, advantage] as y_true
                y_true = np.hstack((oldDisProbs, oldPiProbs, disOneHotActs, conActions, advantage))
            else:
                # NN has only discrete action outputs.
                # pack [oldDisProbs, disOneHotActs, advantage] as y_true
                y_true = np.hstack((oldDisProbs, disOneHotActs, advantage))
        else:
            if self.conActSize != 0:
                # NN has only continuous action outputs.
                oldPiProbs = self.conProb(oldConMusigs[:, 0], oldConMusigs[:, 1], conActions)
                # pack [oldPiProbs, conActions, advantage] as y_true
                y_true = np.hstack((oldPiProbs, conActions, advantage))
            else:
                raise ValueError("trainActor: disActSize and conActSize are both 0, nothing to train")
        # train start; warn if NaNs appear in the packed targets or the resulting losses
        if np.any(tf.math.is_nan(y_true)):
            print("y_true got nan")
            print("y_true", y_true)
        his = self.actor.fit(x=states, y=y_true, epochs=epochs, verbose=0)
        if np.any(tf.math.is_nan(his.history["loss"])):
            print("his.history['loss'] is nan!")
            print(his.history["loss"])
        return his.history["loss"]

    def saveWeights(self, score: float):
        """save now NN's Weight. Use "models.save_weights" method.
        Save as "tf" format "ckpt" file.

        Args:
            score (float): now score
        """
        saveTime = datetime.datetime.now().strftime("%H%M%S")  # one timestamp for all files
        actor_save_dir = self.saveDir + saveTime + "/actor/" + "actor.ckpt"
        critic_save_dir = self.saveDir + saveTime + "/critic/" + "critic.ckpt"
        self.actor.save_weights(actor_save_dir, save_format="tf")
        self.critic.save_weights(critic_save_dir, save_format="tf")
        # create an empty file named after the score to record it
        score_dir = self.saveDir + saveTime + "/" + str(round(score))
        try:
            scorefile = open(score_dir, "w")
        except FileNotFoundError:
            mkdir(self.saveDir + saveTime + "/")
            scorefile = open(score_dir, "w")
        scorefile.close()
        print("PPO Model's Weights Saved")

    def loadWeightToModels(self, loadDir: str):
        """load NN Model. Use "models.load_weights()" method.
        Load "tf" format "ckpt" file.

        Args:
            loadDir (str): Model dir
        """
        actorDir = loadDir + "/actor/" + "actor.ckpt"
        criticDir = loadDir + "/critic/" + "critic.ckpt"
        self.actor.load_weights(actorDir)
        self.critic.load_weights(criticDir)

        print("++++++++++++++++++++++++++++++++++++")
        print("++++++++++++Model Loaded++++++++++++")
        print(loadDir)
        print("++++++++++++++++++++++++++++++++++++")

    def getCriticV(self, state: ndarray):
        """get Critic predict V value

        Args:
            state (ndarray): Env state

        Returns:
            tensor: the Critic's predicted state values
        """
        # if state.ndim < 2:
        #    state = np.expand_dims(state,0)
        if state.ndim != 2:
            stateNum = int(len(state) / self.stateSize)
            state = state.reshape([stateNum, self.stateSize])
        return self.critic.predict(state)

    def discountReward(self, nextState: ndarray, values: ndarray, dones: ndarray, rewards: ndarray):
        """Discount future rewards

        Args:
            nextState (ndarray): next Env state
            values (ndarray): critic predict values
            dones (ndarray): dones from enviroment
            rewards (ndarray): reward list of this episode

        Returns:
            ndarray: discounted rewards list,same shape as rewards that input
        """
        """
        nextV = self.getCriticV(nextState)
        dones = 1 - dones
        discountedRewards = []
        for i in reversed(range(len(rewards))):
            nextV = rewards[i] + dones[i] * self.gamma * nextV
            discountedRewards.append(nextV)
        discountedRewards.reverse()  # reverse
        discountedRewards = np.squeeze(discountedRewards)
        discountedRewards = np.expand_dims(discountedRewards, axis=1)
        # discountedRewards = np.array(discountedRewards)[:, np.newaxis]
        return discountedRewards
        """
        """
        nextV = self.getCriticV(nextState)
        discountedRewards = []
        for r in rewards[::-1]:
            nextV = r + self.gamma * nextV
            discountedRewards.append(nextV)
        discountedRewards.reverse()  # reverse
        discountedRewards = np.squeeze(discountedRewards)
        discountedRewards = np.expand_dims(discountedRewards, axis=1)
        # discountedRewards = np.array(discountedRewards)[:, np.newaxis]
        print(discountedRewards)
        return discountedRewards
        """
        g = 0
        discountedRewards = []
        lastValue = self.getCriticV(nextState)
        values = np.append(values, lastValue, axis=0)
        dones = 1 - dones
        for i in reversed(range(len(rewards))):
            delta = rewards[i] + self.gamma * values[i + 1] * dones[i] - values[i]
            g = delta + self.gamma * self.lmbda * dones[i] * g
            discountedRewards.append(g + values[i])
        discountedRewards.reverse()
        return np.asarray(discountedRewards)

    def getGAE(self, discountedRewards: ndarray, values: ndarray):
        """compute GAE adcantage

        Args:
            discountedRewards (ndarray): discounted rewards
            values (ndarray): critic predict values

        Returns:
            ndarray: GAE advantage
        """
        advantage = discountedRewards - values
        advantage = (advantage - np.mean(advantage)) / (np.std(advantage) + EPS)
        return advantage

    def conProb(self, mu: ndarray, sig: ndarray, x: ndarray):
        """calculate probability when x in Normal distribution(mu,sigma)

        Args:
            mu (ndarray): mu
            sig (ndarray): sigma
            x (ndarray): x

        Returns:
            ndarray: probability
        """
        # probability density of x under the normal distribution N(mu, sigma)
        # return shape : (length, 1)
        mu = np.reshape(mu, (np.size(mu),))
        sig = np.reshape(sig, (np.size(sig),))
        x = np.reshape(x, (np.size(x),))

        dist = tfp.distributions.Normal(mu, sig)
        prob = dist.prob(x)

        prob = np.reshape(prob, (np.size(x), 1))
        # dist = 1./(tf.sqrt(2.*np.pi)*sig)
        # prob = dist*tf.exp(-tf.square(x-mu)/(2.*tf.square(sig)))
        return prob

    def getOneHotActs(self, disActions):
        """one hot action encoder

        Args:
            disActions (ndarray): discrete actions

        Returns:
            ndarray: one hot actions
        """
        actIndex = 0
        for thisShape in self.disActShape:
            thisActs = tf.cast(disActions[:, actIndex], tf.int32)  # tf.one_hot needs integer indices
            thisOneHotAct = tf.squeeze(tf.one_hot(thisActs, thisShape)).numpy()
            if actIndex == 0:
                oneHotActs = thisOneHotAct
            else:
                oneHotActs = np.append(oneHotActs, thisOneHotAct, axis=1)
            actIndex += 1
        return oneHotActs

    def getAverageEntropy(self, probs: ndarray):
        """get average dis&con ACT Entropys

        Args:
            probs (ndarray): actor NN predict result

        Returns:
            float: average total entropy
            list: discrete entropys
            list: continuous entropys
        """
        discreteEntropys = []
        continuousEntropys = []
        if self.disActSize != 0:
            disProbs = probs[:, 0 : self.disOutputSize]
            lastDisActIndex = 0
            for actShape in self.disActShape:
                thisDisProbs = disProbs[:, lastDisActIndex : lastDisActIndex + actShape]
                lastDisActIndex += actShape
                discreteEntropys.append(
                    tf.negative(
                        tf.reduce_mean(
                            tf.math.multiply(thisDisProbs, tf.math.log(thisDisProbs + EPS))
                        )
                    )
                )
        if self.conActSize != 0:
            conProbs = probs[:, self.disOutputSize :]
            conActIndex = 0
            for i in range(self.conActSize):
                thisConProbs = conProbs[:, conActIndex : conActIndex + 2]
                conActIndex += 2
                continuousEntropys.append(tf.reduce_mean(thisConProbs[:, 1]))
        entropyGroups = [np.mean(e) for e in (discreteEntropys, continuousEntropys) if len(e) > 0]
        averageEntropy = np.mean(entropyGroups) if len(entropyGroups) > 0 else 0.0
        return averageEntropy, discreteEntropys, continuousEntropys