import wandb
from torch.utils.tensorboard import SummaryWriter


# Per-target round counters, expected to be incremented by the training loop;
# used below to compute each target's win ratio.
total_rounds = {"Free": 0, "Go": 0, "Attack": 0}
win_rounds = {"Free": 0, "Go": 0, "Attack": 0}
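

# Hypothetical helper (not in the original): a sketch of how the training loop
# is expected to update the counters above when a round finishes, so that the
# logged win ratios stay current.
def record_round_result(target_name: str, won: bool) -> None:
    # Count the finished round for this target, and the win if there was one.
    total_rounds[target_name] += 1
    if won:
        win_rounds[target_name] += 1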


# Records training metrics to TensorBoard and, optionally, syncs them to wandb.
class WandbRecorder:
    def __init__(self, game_name: str, game_type: str, run_name: str, _args) -> None:
        # Store the run configuration and start a wandb run if tracking is enabled.
        self.game_name = game_name
        self.game_type = game_type
        self._args = _args
        self.run_name = run_name
        if self._args.wandb_track:
            wandb.init(
                project=self.game_name,
                entity=self._args.wandb_entity,
                sync_tensorboard=True,
                config=vars(self._args),
                name=self.run_name,
                monitor_gym=True,
                save_code=True,
            )
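        # With sync_tensorboard=True above, scalars written through the
        # SummaryWriter below are mirrored to the wandb run automatically.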
        self.writer = SummaryWriter(f"runs/{self.run_name}")
        self.writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s"
            % ("\n".join([f"|{key}|{value}|" for key, value in vars(self._args).items()])),
        )

    def add_target_scalar(
        self,
        target_name,
        thisT,
        v_loss,
        dis_pg_loss,
        con_pg_loss,
        loss,
        entropy_loss,
        target_reward_mean,
        target_steps,
    ):
        # Log this target's losses, mean reward, and win ratio at its current step.
        step = target_steps[thisT]
        # fmt:off
        self.writer.add_scalar(f"Target{target_name}/value_loss", v_loss.item(), step)
        self.writer.add_scalar(f"Target{target_name}/dis_policy_loss", dis_pg_loss.item(), step)
        self.writer.add_scalar(f"Target{target_name}/con_policy_loss", con_pg_loss.item(), step)
        self.writer.add_scalar(f"Target{target_name}/total_loss", loss.item(), step)
        self.writer.add_scalar(f"Target{target_name}/entropy_loss", entropy_loss.item(), step)
        self.writer.add_scalar(f"Target{target_name}/Reward", target_reward_mean, step)
        # max(..., 1) guards against a ZeroDivisionError before any round has finished.
        self.writer.add_scalar(f"Target{target_name}/WinRatio", win_rounds[target_name] / max(total_rounds[target_name], 1), step)
        # fmt:on

    def add_global_scalar(
        self,
        total_reward_mean,
        learning_rate,
        total_steps,
    ):
        self.writer.add_scalar("GlobalCharts/TotalRewardMean", total_reward_mean, total_steps)
        self.writer.add_scalar("GlobalCharts/learning_rate", learning_rate, total_steps)

    def add_win_ratio(self, target_name, target_steps):
        # Log only the running win ratio for one target; max(..., 1) guards
        # against a ZeroDivisionError before any round has finished.
        self.writer.add_scalar(
            f"Target{target_name}/WinRatio",
            win_rounds[target_name] / max(total_rounds[target_name], 1),
            target_steps,
        )
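
    def close(self) -> None:
        # Not part of the original class: a convenience sketch that flushes and
        # closes the TensorBoard writer and ends the wandb run if one was started.
        self.writer.close()
        if self._args.wandb_track:
            wandb.finish()


# Minimal usage sketch (assumptions: wandb tracking is disabled so this runs
# offline; SimpleNamespace stands in for the parsed argparse namespace and
# sets only the attributes WandbRecorder actually reads).
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(wandb_track=False, wandb_entity=None)
    recorder = WandbRecorder("demo-game", "demo", "demo-run", args)
    recorder.add_global_scalar(total_reward_mean=0.0, learning_rate=3e-4, total_steps=0)
    recorder.add_win_ratio("Free", target_steps=0)
    recorder.close()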