import os
import random
import numpy as np


class GAILMem(object):
    """In-memory buffer of step samples for GAIL training.

    Each call to :meth:`saveMems` appends one step (state, actor output,
    action, reward, done flag) to five parallel lists and increments
    ``memNum``. The buffer can be written to / restored from a ``.npz``
    archive and read back as ``ndarray`` batches normalized by
    :meth:`standDims`.
    """

    def __init__(self):
        # Reuse clearMem so the initial state and the "cleared" state are
        # defined in exactly one place.
        self.clearMem()
        print("√√√√√Buffer Initialized Success√√√√√")

    def clearMem(self) -> None:
        """Drop every stored step and reset the step counter to 0."""
        self.states = []
        self.actorProbs = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.memNum = 0

    def saveMemtoFile(self, dir: str) -> None:
        """Save the buffer as ``<dir>/pack-<memNum>.npz``.

        Missing directories, including intermediate parents, are created.

        Args:
            dir (str): target directory, e.g. ``"GAIL-Expert-Data/"``. The
                trailing separator is optional. An empty string saves into
                the current working directory.
        """
        # os.path.join inserts a separator only when needed, so "out" and
        # "out/" both give "out/pack-N". Plain concatenation produced
        # "outpack-N" when the trailing "/" was forgotten.
        savePath = os.path.join(dir, "pack-" + str(self.memNum))
        if dir:
            # makedirs also creates missing parents (os.mkdir did not).
            # exist_ok makes an already-existing directory a no-op.
            os.makedirs(dir, exist_ok=True)
        # np.savez appends the ".npz" extension to savePath.
        np.savez(
            savePath,
            states=np.asarray(self.states),
            actorProbs=np.asarray(self.actorProbs),
            actions=np.asarray(self.actions),
            rewards=np.asarray(self.rewards),
            dones=np.asarray(self.dones),
        )

    def loadMemFile(self, dir: str) -> None:
        """Replace the buffer contents with the data in a ``.npz`` archive.

        The archive must contain the arrays ``states``, ``actorProbs``,
        ``actions``, ``rewards`` and ``dones``, as written by
        :meth:`saveMemtoFile`. Each array is converted back into a
        (possibly nested) Python list with ``ndarray.tolist``.

        Args:
            dir (str): path to the ``.npz`` file, including the extension.

        Raises:
            FileNotFoundError: if the file does not exist.
            KeyError: if one of the expected arrays is missing.
            In both cases the current buffer is left unchanged.
        """
        # NOTE(review): allow_pickle=True lets np.load unpickle object
        # arrays, which can execute arbitrary code. Only load trusted files.
        # The context manager closes the archive's underlying file handle.
        with np.load(dir, allow_pickle=True) as memFile:
            states = memFile["states"].tolist()
            actorProbs = memFile["actorProbs"].tolist()
            actions = memFile["actions"].tolist()
            rewards = memFile["rewards"].tolist()
            dones = memFile["dones"].tolist()
        # Assign only after every array was read successfully, so a bad
        # file cannot leave the buffer cleared or half-populated.
        self.states = states
        self.actorProbs = actorProbs
        self.actions = actions
        self.rewards = rewards
        self.dones = dones
        self.memNum = len(states)

    def getRandomSample(self, sampleNum: int = 0) -> tuple:
        """Return a random subset of stored steps, without replacement.

        Args:
            sampleNum (int, optional): number of steps to draw. 0 returns
                every stored step in insertion order. Defaults to 0.

        Returns:
            tuple: ``(states, actorProbs, actions, rewards, dones)``. Each
            item is an ndarray normalized by :meth:`standDims`, and all
            items share the same row order.

        Raises:
            ValueError: if ``sampleNum`` is negative or exceeds ``memNum``
                (raised by ``random.sample``).
        """
        if sampleNum == 0:
            return (
                self.getStates(),
                self.getActorProbs(),
                self.getActions(),
                self.getRewards(),
                self.getDones(),
            )
        # Draw the indices once and apply the same ones to every field, so
        # rows stay aligned across the five arrays.
        randIndex = random.sample(range(self.memNum), sampleNum)
        return tuple(
            self.standDims(np.asarray(field)[randIndex])
            for field in (
                self.states,
                self.actorProbs,
                self.actions,
                self.rewards,
                self.dones,
            )
        )

    def getStates(self) -> np.ndarray:
        """Return all stored states as an ndarray normalized by standDims."""
        return self.standDims(np.asarray(self.states))

    def getActorProbs(self) -> np.ndarray:
        """Return all stored actor outputs as an ndarray normalized by standDims."""
        return self.standDims(np.asarray(self.actorProbs))

    def getActions(self) -> np.ndarray:
        """Return all stored actions as an ndarray normalized by standDims."""
        return self.standDims(np.asarray(self.actions))

    def getRewards(self) -> np.ndarray:
        """Return all stored rewards as an ndarray normalized by standDims."""
        return self.standDims(np.asarray(self.rewards))

    def getDones(self) -> np.ndarray:
        """Return all stored done flags as an ndarray normalized by standDims."""
        return self.standDims(np.asarray(self.dones))

    def standDims(self, data) -> np.ndarray:
        """Normalize the dimensionality of a batch.

        - More than 2 dims: axis 1 is squeezed out, e.g. (N, 1, D) -> (N, D).
          numpy raises ValueError if axis 1 does not have length 1.
        - Exactly 1 dim: a new axis is inserted at position 1, e.g.
          (N,) -> (N, 1).
        - Exactly 2 dims: returned unchanged, as an ndarray.

        Args:
            data (array_like): list or ndarray to normalize.

        Returns:
            ndarray: the normalized data.
        """
        if np.ndim(data) > 2:
            return np.squeeze(data, axis=1)
        elif np.ndim(data) < 2:
            return np.expand_dims(data, axis=1)
        else:
            return np.asarray(data)

    def saveMems(self, state, actorProb, action, reward, done) -> None:
        """Append one step to the buffer and increment ``memNum``.

        Args:
            state: observed state.
            actorProb: actor's predicted output for this state.
            action: action chosen by the actor.
            reward: reward received for this step.
            done: episode-termination flag (presumably a bool; confirm
                with callers).
        """
        self.states.append(state)
        self.actorProbs.append(actorProb)
        self.actions.append(action)
        self.rewards.append(reward)
        self.dones.append(done)
        self.memNum += 1