import matplotlib.pyplot as plt

# Colour palette for GAILHistory.drawHis: DarkBlue fills the figure and axes
# backgrounds; DarkWhite is used for plotted lines, ticks, spines and titles.
DarkBlue, DarkWhite = "#011627", "#c9d2df"


class GAILHistory(object):
    """Accumulate training metrics across calls and plot them on one figure.

    Every attribute is a flat list that grows with each ``saveHis`` call.
    Scalar metrics gain exactly one entry per call. The loss arguments and the
    discriminator-reward argument are iterables whose elements are all
    appended, so those lists may grow by several entries (or none) per call.
    """

    def __init__(self):
        self.meanRewards = []  # one entry per saveHis call
        self.discrimLosses = []  # all elements of each dLosses argument
        self.actorLosses = []  # all elements of each aLosses argument
        self.criticLosses = []  # all elements of each cLosses argument
        self.demoAccs = []  # one entry per saveHis call
        self.agentAccs = []  # one entry per saveHis call
        self.averageEntropys = []  # one entry per saveHis call
        self.discrimRewards = []  # all elements of each discrimReward argument

    def saveHis(
        self, rewards, dLosses, aLosses, cLosses, demoAcc, agentAcc, averageEntropy, discrimReward
    ):
        """Record the metrics produced by one training step.

        Args:
            rewards: mean reward value; stored as a single entry.
            dLosses: iterable of discriminator losses; every element is stored.
            aLosses: iterable of actor losses; every element is stored.
            cLosses: iterable of critic losses; every element is stored.
            demoAcc: demonstration accuracy value; stored as a single entry.
            agentAcc: agent accuracy value; stored as a single entry.
            averageEntropy: average entropy value; stored as a single entry.
            discrimReward: iterable of discriminator rewards; every element
                is stored.

        Raises:
            TypeError: if one of the iterable arguments is not iterable.
        """
        # Statement order matches the original, so partial state on a
        # TypeError is unchanged. append(x) is equivalent to extend([x]).
        self.meanRewards.append(rewards)
        self.discrimLosses.extend(dLosses)
        self.actorLosses.extend(aLosses)
        self.criticLosses.extend(cLosses)
        self.demoAccs.append(demoAcc)
        self.agentAccs.append(agentAcc)
        self.averageEntropys.append(averageEntropy)
        self.discrimRewards.extend(discrimReward)

    def drawHis(self, show=True):
        """Plot every recorded metric on a 4x2 grid of dark-themed subplots.

        Args:
            show: when True (the default, matching the previous behaviour),
                call ``plt.show()`` after drawing. Pass False to get the
                figure back without displaying it, e.g. to save it to a file.

        Returns:
            The matplotlib Figure that holds the eight subplots.
        """

        def styleAxes(ax, data, title):
            # Dark background with light ticks, spines, line and title.
            ax.set_facecolor(DarkBlue)
            ax.tick_params(colors=DarkWhite)
            for side in ("top", "bottom", "left", "right"):
                ax.spines[side].set_color(DarkWhite)
            # x is the record index 0..len(data)-1.
            ax.plot(range(len(data)), data, color=DarkWhite, label=title)
            ax.set_title(title, color=DarkWhite)

        # Row-major order: left column and right column alternate, which
        # reproduces the original ax1..ax8 layout.
        panels = (
            (self.meanRewards, "meanRewards"),
            (self.discrimLosses, "discrimLosses"),
            (self.demoAccs, "demoAccs"),
            (self.actorLosses, "actorLosses"),
            (self.agentAccs, "agentAccs"),
            (self.criticLosses, "criticLosses"),
            (self.averageEntropys, "averageEntropys"),
            (self.discrimRewards, "discrimRewards"),
        )

        fig, axes = plt.subplots(4, 2, figsize=(21, 13), facecolor=DarkBlue)
        for ax, (data, title) in zip(axes.flat, panels):
            styleAxes(ax, data, title)
        if show:
            plt.show()
        return fig