{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Action, 1 continuous ctrl 2.1\n",
      "Action, 0 continuous ctrl -1.1\n"
     ]
    }
   ],
   "source": [
    "import gym\n",
    "from gym.spaces import Dict, Discrete, Box, Tuple\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class SampleGym(gym.Env):\n",
    "    \"\"\"Toy environment with a hybrid (discrete + continuous) action space.\n",
    "\n",
    "    An action is a pair ``(choice, controls)``: ``choice`` selects which entry\n",
    "    of ``controls`` is used to compute the reward. Observations are sampled at\n",
    "    random from ``observation_space``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config=None):\n",
    "        \"\"\"Build the spaces and read ``p_done`` (default 0.1) from ``config``.\n",
    "\n",
    "        ``config`` defaults to ``None`` rather than ``{}`` so that instances\n",
    "        created without arguments never share one mutable dictionary.\n",
    "        \"\"\"\n",
    "        self.config = {} if config is None else config\n",
    "        self.action_space = Tuple((Discrete(2), Box(-10, 10, (2,))))\n",
    "        self.observation_space = Box(-10, 10, (2, 2))\n",
    "        self.p_done = self.config.get(\"p_done\", 0.1)\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Return a randomly sampled observation.\"\"\"\n",
    "        return self.observation_space.sample()\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Apply ``action`` and return ``(observation, reward, done, info)``.\n",
    "\n",
    "        The reward is the selected control value when choice 0 is taken and\n",
    "        ``-control - 1`` otherwise. ``done`` is True with probability\n",
    "        ``p_done``; ``info`` is always an empty dict.\n",
    "        \"\"\"\n",
    "        chosen_action = action[0]\n",
    "        # Only the control that belongs to the chosen discrete action is used.\n",
    "        cnt_control = action[1][chosen_action]\n",
    "\n",
    "        if chosen_action == 0:\n",
    "            reward = cnt_control\n",
    "        else:\n",
    "            reward = -cnt_control - 1\n",
    "\n",
    "        print(f\"Action, {chosen_action} continuous ctrl {cnt_control}\")\n",
    "        return (\n",
    "            self.observation_space.sample(),\n",
    "            reward,\n",
    "            bool(np.random.choice([True, False], p=[self.p_done, 1.0 - self.p_done])),\n",
    "            {},\n",
    "        )\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Smoke test: every step echoes the chosen action and its control value.\n",
    "    env = SampleGym()\n",
    "    env.reset()\n",
    "    demo_actions = (\n",
    "        (1, [-1, 2.1]),  # should say use action 1 with 2.1\n",
    "        (0, [-1.1, 2.1]),  # should say use action 0 with -1.1\n",
    "    )\n",
    "    for demo_action in demo_actions:\n",
    "        env.step(demo_action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mlagents_envs.environment import UnityEnvironment\n",
    "from gym_unity.envs import UnityToGymWrapper\n",
    "import numpy as np\n",
    "\n",
    "ENV_PATH = \"../Build-ParallelEnv/Aimbot-ParallelEnv\"\n",
    "WORKER_ID = 1\n",
    "BASE_PORT = 2002\n",
    "\n",
    "# Launch the Unity build and query it through the low-level ML-Agents API.\n",
    "env = UnityEnvironment(\n",
    "    file_name=ENV_PATH,\n",
    "    seed=1,\n",
    "    side_channels=[],\n",
    "    worker_id=WORKER_ID,\n",
    "    base_port=BASE_PORT,\n",
    ")\n",
    "\n",
    "trackedAgent = 0\n",
    "# Defaults so these names exist even when the tracked agent shows up in\n",
    "# neither the decision steps nor the terminal steps.\n",
    "nextState = reward = done = None\n",
    "\n",
    "try:\n",
    "    env.reset()\n",
    "    BEHA_SPECS = env.behavior_specs\n",
    "    BEHA_NAME = list(BEHA_SPECS)[0]\n",
    "    SPEC = BEHA_SPECS[BEHA_NAME]\n",
    "    print(SPEC)\n",
    "\n",
    "    decisionSteps, terminalSteps = env.get_steps(BEHA_NAME)\n",
    "finally:\n",
    "    # Shut the Unity process down so this worker id / port is released;\n",
    "    # the step data fetched above stays usable after the environment closes.\n",
    "    env.close()\n",
    "\n",
    "if trackedAgent in decisionSteps:  # episode still running: the state is stored in decision_steps\n",
    "    nextState = decisionSteps[trackedAgent].obs[0]\n",
    "    reward = decisionSteps[trackedAgent].reward\n",
    "    done = False\n",
    "# Deliberately a second `if`, not `elif`: when the agent is in both sets the\n",
    "# terminal values win, exactly as the sequential checks always behaved.\n",
    "if trackedAgent in terminalSteps:  # episode finished: the state is stored in terminal_steps\n",
    "    nextState = terminalSteps[trackedAgent].obs[0]\n",
    "    reward = terminalSteps[trackedAgent].reward\n",
    "    done = True\n",
    "print(decisionSteps.agent_id)\n",
    "print(terminalSteps)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "decisionSteps.agent_id [1 2 5 7]\n",
      "decisionSteps.agent_id_to_index {1: 0, 2: 1, 5: 2, 7: 3}\n",
      "decisionSteps.reward [0. 0. 0. 0.]\n",
      "decisionSteps.action_mask [array([[False, False, False],\n",
      "       [False, False, False],\n",
      "       [False, False, False],\n",
      "       [False, False, False]]), array([[False, False, False],\n",
      "       [False, False, False],\n",
      "       [False, False, False],\n",
      "       [False, False, False]]), array([[False, False],\n",
      "       [False, False],\n",
      "       [False, False],\n",
      "       [False, False]])]\n",
      "decisionSteps.obs [  0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.        -15.994009    1.        -26.322788    1.\n",
      "   1.          1.          1.          1.          1.          2.\n",
      "   1.          1.          1.          1.          1.          1.\n",
      "   1.          1.3519633   1.6946528   2.3051548   3.673389    9.067246\n",
      "  17.521473   21.727095   22.753294   24.167128   25.905216   18.35725\n",
      "  21.02278    21.053417    0.       ]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'decisionSteps.obs [array([[-15.994009 ,   1.       , -26.322788 ,   1.       ,   1.       ,\\n          1.       ,   1.       ,   1.       ,   1.       ,   2.       ,\\n          1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\\n          1.       ,   1.       ,   1.3519633,   1.6946528,   2.3051548,\\n          3.673389 ,   9.067246 ,  17.521473 ,  21.727095 ,  22.753294 ,\\n         24.167128 ,  25.905216 ,  18.35725  ,  21.02278  ,  21.053417 ,\\n          0.       ],\\n       [ -1.8809433,   1.       , -25.66834  ,   1.       ,   2.       ,\\n          1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\\n          1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\\n          1.       ,   1.       ,  16.768637 ,  23.414627 ,  22.04486  ,\\n         21.050663 ,  20.486784 ,  20.486784 ,  21.050665 ,  15.049731 ,\\n         11.578419 ,   9.695194 ,  20.398016 ,  20.368341 ,  20.398016 ,\\n...\\n         20.551746 ,  20.00118  ,  20.001116 ,  20.551594 ,  21.5222   ,\\n         17.707508 ,  14.86889  ,  19.914494 ,  19.885508 ,  19.914463 ,\\n          0.       ]], dtype=float32)]'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Print the fields of `decisionSteps`. A sample of each output from an earlier\n",
    "# run is kept as a `#` comment under the statement that produced it; real\n",
    "# comments (instead of bare string literals) keep the cell from echoing the\n",
    "# last sample as its result.\n",
    "print(\"decisionSteps.agent_id\",decisionSteps.agent_id)\n",
    "# decisionSteps.agent_id [1 2 5 7]\n",
    "print(\"decisionSteps.agent_id_to_index\",decisionSteps.agent_id_to_index)\n",
    "# decisionSteps.agent_id_to_index {1: 0, 2: 1, 5: 2, 7: 3}\n",
    "print(\"decisionSteps.reward\",decisionSteps.reward)\n",
    "# decisionSteps.reward [0. 0. 0. 0.]\n",
    "print(\"decisionSteps.action_mask\",decisionSteps.action_mask)\n",
    "# decisionSteps.action_mask [array([[False, False, False],\n",
    "#        [False, False, False],\n",
    "#        [False, False, False],\n",
    "#        [False, False, False]]), array([[False, False, False],\n",
    "#        [False, False, False],\n",
    "#        [False, False, False],\n",
    "#        [False, False, False]]), array([[False, False],\n",
    "#        [False, False],\n",
    "#        [False, False],\n",
    "#        [False, False]])]\n",
    "# Only the first row of the first observation array is printed here, so the\n",
    "# label names that exact expression.\n",
    "print(\"decisionSteps.obs[0][0]\", decisionSteps.obs[0][0])\n",
    "# Sample of the complete `decisionSteps.obs` list from an earlier run\n",
    "# (middle rows elided with \"...\"):\n",
    "# decisionSteps.obs [array([[-15.994009 ,   1.       , -26.322788 ,   1.       ,   1.       ,\n",
    "#           1.       ,   1.       ,   1.       ,   1.       ,   2.       ,\n",
    "#           1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
    "#           1.       ,   1.       ,   1.3519633,   1.6946528,   2.3051548,\n",
    "#           3.673389 ,   9.067246 ,  17.521473 ,  21.727095 ,  22.753294 ,\n",
    "#          24.167128 ,  25.905216 ,  18.35725  ,  21.02278  ,  21.053417 ,\n",
    "#           0.       ],\n",
    "#        [ -1.8809433,   1.       , -25.66834  ,   1.       ,   2.       ,\n",
    "#           1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
    "#           1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
    "#           1.       ,   1.       ,  16.768637 ,  23.414627 ,  22.04486  ,\n",
    "#          21.050663 ,  20.486784 ,  20.486784 ,  21.050665 ,  15.049731 ,\n",
    "#          11.578419 ,   9.695194 ,  20.398016 ,  20.368341 ,  20.398016 ,\n",
    "# ...\n",
    "#          20.551746 ,  20.00118  ,  20.001116 ,  20.551594 ,  21.5222   ,\n",
    "#          17.707508 ,  14.86889  ,  19.914494 ,  19.885508 ,  19.914463 ,\n",
    "#           0.       ]], dtype=float32)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from AimbotEnv import Aimbot\n",
    "\n",
    "# Connection settings for the Unity build wrapped by Aimbot.\n",
    "ENV_PATH = \"../Build-ParallelEnv/Aimbot-ParallelEnv\"\n",
    "WORKER_ID, BASE_PORT = 1, 2002\n",
    "\n",
    "env = Aimbot(\n",
    "    envPath=ENV_PATH,\n",
    "    workerID=WORKER_ID,\n",
    "    basePort=BASE_PORT,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[  0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       , -15.994009 ,   1.       , -26.322788 ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
       "           2.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   1.3519633,   1.6946528,\n",
       "           2.3051548,   3.673389 ,   9.067246 ,  17.521473 ,  21.727095 ,\n",
       "          22.753294 ,  24.167128 ,  25.905216 ,  18.35725  ,  21.02278  ,\n",
       "          21.053417 ,   0.       , -15.994003 ,   1.       , -26.322784 ,\n",
       "           1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   1.       ,   1.3519667,\n",
       "           1.6946585,   2.3051722,   3.6734192,   9.067533 ,  21.145092 ,\n",
       "          21.727148 ,  22.753365 ,  24.167217 ,  25.905317 ,  18.358263 ,\n",
       "          21.022812 ,  21.053455 ,   0.       ],\n",
       "        [  0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,  -1.8809433,   1.       , -25.66834  ,   1.       ,\n",
       "           2.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,  16.768637 ,  23.414627 ,\n",
       "          22.04486  ,  21.050663 ,  20.486784 ,  20.486784 ,  21.050665 ,\n",
       "          15.049731 ,  11.578419 ,   9.695194 ,  20.398016 ,  20.368341 ,\n",
       "          20.398016 ,   0.       ,  -1.8809433,   1.       , -25.66834  ,\n",
       "           1.       ,   1.       ,   2.       ,   1.       ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   1.       ,   2.       ,\n",
       "           2.       ,   1.       ,   1.       ,   1.       ,  25.098585 ,\n",
       "          15.749494 ,  22.044899 ,  21.050697 ,  20.486813 ,  20.486813 ,\n",
       "          21.050694 ,  15.049746 ,   3.872317 ,   3.789325 ,  20.398046 ,\n",
       "          20.368372 ,  20.398046 ,   0.       ],\n",
       "        [  0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       , -13.672583 ,   1.       , -26.479263 ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   5.3249803,   6.401276 ,\n",
       "           8.374101 ,  12.8657875,  21.302414 ,  21.30242  ,  21.888742 ,\n",
       "          22.92251  ,  24.346794 ,  26.09773  ,  21.210114 ,  21.179258 ,\n",
       "          21.210117 ,   0.       , -13.672583 ,   1.       , -26.479263 ,\n",
       "           1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
       "           1.       ,   1.       ,   2.       ,   1.       ,   1.       ,\n",
       "           2.       ,   1.       ,   1.       ,   2.       ,   5.3249855,\n",
       "           6.4012837,   8.374114 ,  12.865807 ,  21.302446 ,  21.30245  ,\n",
       "          16.168503 ,  22.922543 ,  24.346823 ,   7.1110754,  21.210148 ,\n",
       "          21.17929  ,  12.495141 ,   0.       ],\n",
       "        [  0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,   0.       ,   0.       ,   0.       ,   0.       ,\n",
       "           0.       ,  -4.9038744,   1.       , -25.185507 ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   1.       ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,  20.33171  ,  22.859762 ,\n",
       "          21.522427 ,  20.551746 ,  20.00118  ,  20.001116 ,  20.551594 ,\n",
       "          21.5222   ,  17.707508 ,  14.86889  ,  19.914494 ,  19.885508 ,\n",
       "          19.914463 ,   0.       ,  -4.9038773,   1.       , -25.185507 ,\n",
       "           1.       ,   2.       ,   1.       ,   2.       ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   2.       ,   1.       ,\n",
       "           1.       ,   1.       ,   1.       ,   1.       ,  15.905993 ,\n",
       "          22.85977  ,  11.566693 ,  20.551773 ,  20.00121  ,  20.001146 ,\n",
       "          20.551619 ,   7.135157 ,  17.707582 ,  14.868943 ,  19.914528 ,\n",
       "          19.88554  ,  19.914494 ,   0.       ]], dtype=float32),\n",
       " [[-0.05], [-0.05], [-0.05], [-0.05]],\n",
       " [[False], [False], [False], [False]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "# The next two expressions are evaluated but their values are discarded:\n",
    "# a notebook cell only displays the value of its final expression.\n",
    "env.unity_observation_shape\n",
    "(128, 4) + env.unity_observation_shape\n",
    "# Final expression of the cell, so whatever reset() returns is displayed.\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1, 2, 3],\n",
      "        [1, 2, 3],\n",
      "        [1, 2, 3],\n",
      "        [1, 2, 3]], device='cuda:0')\n",
      "tensor([[1],\n",
      "        [2],\n",
      "        [3],\n",
      "        [4]], device='cuda:0')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2, 3, 1],\n",
       "        [1, 2, 3, 2],\n",
       "        [1, 2, 3, 3],\n",
       "        [1, 2, 3, 4]], device='cuda:0')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "aa = torch.tensor([[1,2,3],[1,2,3],[1,2,3],[1,2,3]]).to(\"cuda:0\")\n",
    "bb = torch.tensor([[1],[2],[3],[4]]).to(\"cuda:0\")\n",
    "print(aa)\n",
    "print(bb)\n",
    "torch.cat([aa,bb],axis = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "Can't get attribute 'PPOAgent' on <module '__main__'>",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_31348\\1930153251.py\u001b[0m in \u001b[0;36m<cell line: 2>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mmymodel\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"../PPO-Model/SmallArea-256-128-hybrid.pt\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      3\u001b[0m \u001b[0mmymodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0meval\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\Users\\UCUNI\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36mload\u001b[1;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[0;32m    710\u001b[0m                     \u001b[0mopened_file\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mseek\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0morig_position\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    711\u001b[0m                     \u001b[1;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 712\u001b[1;33m                 \u001b[1;32mreturn\u001b[0m \u001b[0m_load\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopened_zipfile\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    713\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0m_legacy_load\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mopened_file\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmap_location\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpickle_module\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    714\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\Users\\UCUNI\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36m_load\u001b[1;34m(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)\u001b[0m\n\u001b[0;32m   1047\u001b[0m     \u001b[0munpickler\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mUnpicklerWrapper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mdata_file\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mpickle_load_args\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1048\u001b[0m     \u001b[0munpickler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpersistent_load\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mpersistent_load\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1049\u001b[1;33m     \u001b[0mresult\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0munpickler\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1050\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1051\u001b[0m     \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_utils\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_validate_loaded_sparse_tensors\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mc:\\Users\\UCUNI\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\serialization.py\u001b[0m in \u001b[0;36mfind_class\u001b[1;34m(self, mod_name, name)\u001b[0m\n\u001b[0;32m   1040\u001b[0m                     \u001b[1;32mpass\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1041\u001b[0m             \u001b[0mmod_name\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mload_module_mapping\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmod_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmod_name\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1042\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfind_class\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmod_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mname\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1043\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1044\u001b[0m     \u001b[1;31m# Load the data (which may in turn use `persistent_load` to load tensors)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mAttributeError\u001b[0m: Can't get attribute 'PPOAgent' on <module '__main__'>"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.distributions import Categorical, Normal\n",
    "\n",
    "def layer_init(layer, std=np.sqrt(2), bias_const=0.0):\n",
    "    torch.nn.init.orthogonal_(layer.weight, std)\n",
    "    torch.nn.init.constant_(layer.bias, bias_const)\n",
    "    return layer\n",
    "\n",
    "class PPOAgent(nn.Module):\n",
    "    \"\"\"Actor-critic network for a hybrid (multi-discrete + continuous) action space.\n",
    "\n",
    "    A shared two-layer MLP feeds three heads: discrete logits (split into one\n",
    "    categorical distribution per branch), the mean of a Normal distribution\n",
    "    whose log-std is a learned, state-independent parameter, and a scalar\n",
    "    value estimate.\n",
    "    \"\"\"\n",
    "\n",
    "    # The annotation is a string so the class can be defined (for example to\n",
    "    # unpickle a saved agent) without the Aimbot class being imported.\n",
    "    def __init__(self, env: \"Aimbot\"):\n",
    "        \"\"\"Size the network from the ``unity_*`` attributes of ``env``.\"\"\"\n",
    "        super().__init__()\n",
    "        self.discrete_size = env.unity_discrete_size\n",
    "        self.discrete_shape = list(env.unity_discrete_branches)\n",
    "        self.continuous_size = env.unity_continuous_size\n",
    "\n",
    "        # The product is a NumPy integer; hand nn.Linear a plain Python int.\n",
    "        obs_size = int(np.array(env.unity_observation_shape).prod())\n",
    "        self.network = nn.Sequential(\n",
    "            layer_init(nn.Linear(obs_size, 256)),\n",
    "            nn.ReLU(),\n",
    "            layer_init(nn.Linear(256, 128)),\n",
    "            nn.ReLU(),\n",
    "        )\n",
    "        self.actor_dis = layer_init(nn.Linear(128, self.discrete_size), std=0.01)\n",
    "        self.actor_mean = layer_init(nn.Linear(128, self.continuous_size), std=0.01)\n",
    "        self.actor_logstd = nn.Parameter(torch.zeros(1, self.continuous_size))\n",
    "        self.critic = layer_init(nn.Linear(128, 1), std=1)\n",
    "\n",
    "    def get_value(self, state: torch.Tensor):\n",
    "        \"\"\"Return the critic's value estimate for ``state``.\"\"\"\n",
    "        return self.critic(self.network(state))\n",
    "\n",
    "    def get_actions_value(self, state: torch.Tensor, actions=None):\n",
    "        \"\"\"Sample actions (or score the given ones) and evaluate the state.\n",
    "\n",
    "        Args:\n",
    "            state: batch of observations fed to the shared network.\n",
    "            actions: optional batch of previously taken actions, one row per\n",
    "                sample laid out as ``[discrete branches..., continuous values...]``.\n",
    "                When omitted, new actions are sampled from the current policy.\n",
    "\n",
    "        Returns:\n",
    "            ``(actions, discrete log-prob, discrete entropy, continuous\n",
    "            log-prob, continuous entropy, value)``. The discrete terms are\n",
    "            summed over branches and the continuous terms over action\n",
    "            dimensions.\n",
    "        \"\"\"\n",
    "        hidden = self.network(state)\n",
    "        # discrete: one categorical distribution per action branch\n",
    "        dis_logits = self.actor_dis(hidden)\n",
    "        split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)\n",
    "        multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]\n",
    "        # continuous: Normal with a learned log-std shared across the batch\n",
    "        actions_mean = self.actor_mean(hidden)\n",
    "        action_logstd = self.actor_logstd.expand_as(actions_mean)\n",
    "        action_std = torch.exp(action_logstd)\n",
    "        con_probs = Normal(actions_mean, action_std)\n",
    "\n",
    "        if actions is None:\n",
    "            disAct = torch.stack([ctgr.sample() for ctgr in multi_categoricals])\n",
    "            conAct = con_probs.sample()\n",
    "            actions = torch.cat([disAct.T, conAct], dim=1)\n",
    "        else:\n",
    "            # The leading columns hold one discrete action per branch, matching\n",
    "            # the layout built in the sampling branch above. The count comes from\n",
    "            # the model itself instead of a global `env`, so a reloaded agent can\n",
    "            # score actions without a live environment.\n",
    "            num_branches = len(self.discrete_shape)\n",
    "            disAct = actions[:, 0:num_branches].T\n",
    "            conAct = actions[:, num_branches:]\n",
    "        dis_log_prob = torch.stack(\n",
    "            [ctgr.log_prob(act) for act, ctgr in zip(disAct, multi_categoricals)]\n",
    "        )\n",
    "        dis_entropy = torch.stack([ctgr.entropy() for ctgr in multi_categoricals])\n",
    "        return (\n",
    "            actions,\n",
    "            dis_log_prob.sum(0),\n",
    "            dis_entropy.sum(0),\n",
    "            con_probs.log_prob(conAct).sum(1),\n",
    "            con_probs.entropy().sum(1),\n",
    "            self.critic(hidden),\n",
    "        )\n",
    "\n",
    "\n",
    "# The checkpoint pickles the whole agent object, so PPOAgent must already be\n",
    "# defined in this kernel before loading.\n",
    "# NOTE(review): torch.load unpickles arbitrary Python objects; only load\n",
    "# checkpoints that come from a trusted source.\n",
    "mymodel = torch.load(\n",
    "    \"../PPO-Model/SmallArea-256-128-hybrid.pt\",\n",
    "    # Keep the saved device placement when CUDA exists, otherwise remap to CPU.\n",
    "    map_location=None if torch.cuda.is_available() else \"cpu\",\n",
    ")\n",
    "mymodel.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "x = torch.randn(2, 3).to(\"cuda\")\n",
    "print(x)\n",
    "print(torch.cat((x, x, x), 0))\n",
    "print(torch.cat((x, x, x), 1))\n",
    "\n",
    "aa = torch.empty(0).to(\"cuda\")\n",
    "torch.cat([aa,x])\n",
    "bb = [[]]*2\n",
    "print(bb)\n",
    "bb.append(x.to(\"cpu\").tolist())\n",
    "bb.append(x.to(\"cpu\").tolist())\n",
    "print(bb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-1.1090,  0.4686,  0.6883],\n",
      "        [-0.1862, -0.3943, -0.0202],\n",
      "        [ 0.1436, -0.9444, -1.2079],\n",
      "        [-2.9434, -2.5989, -0.6653],\n",
      "        [ 0.4668,  0.8548, -0.4641],\n",
      "        [-0.3956, -0.2832, -0.1889],\n",
      "        [-0.2801, -0.2092,  1.7254],\n",
      "        [ 2.7938, -0.7742,  0.7053]], device='cuda:0')\n",
      "(8, 0)\n",
      "---\n",
      "[[array([-1.1090169,  0.4685607,  0.6883437], dtype=float32)], [array([-0.1861974 , -0.39429024, -0.02016036], dtype=float32)], [array([ 0.14360362, -0.9443668 , -1.2079065 ], dtype=float32)], [array([-2.9433894 , -2.598913  , -0.66532046], dtype=float32)], [array([ 0.46684313,  0.8547877 , -0.46408093], dtype=float32)], [array([-0.39563984, -0.2831819 , -0.18891   ], dtype=float32)], [array([-0.28008553, -0.20918302,  1.7253567 ], dtype=float32)], [array([ 2.7938051, -0.7742478,  0.705279 ], dtype=float32)]]\n",
      "[[array([-1.1090169,  0.4685607,  0.6883437], dtype=float32)], [], [array([ 0.14360362, -0.9443668 , -1.2079065 ], dtype=float32)], [array([-2.9433894 , -2.598913  , -0.66532046], dtype=float32)], [array([ 0.46684313,  0.8547877 , -0.46408093], dtype=float32)], [array([-0.39563984, -0.2831819 , -0.18891   ], dtype=float32)], [array([-0.28008553, -0.20918302,  1.7253567 ], dtype=float32)], [array([ 2.7938051, -0.7742478,  0.705279 ], dtype=float32)]]\n",
      "---\n",
      "[array([-1.1090169,  0.4685607,  0.6883437], dtype=float32), array([-1.1090169,  0.4685607,  0.6883437], dtype=float32)]\n",
      "vvv tensor([[-1.1090,  0.4686,  0.6883],\n",
      "        [-1.1090,  0.4686,  0.6883]], device='cuda:0')\n",
      "tensor([[-1.1090,  0.4686,  0.6883],\n",
      "        [-1.1090,  0.4686,  0.6883]], device='cuda:0')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 64,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "agent_num = 8\n",
    "ob_buffer = [[]for i in range(agent_num)]\n",
    "obs =  torch.randn(8, 3).to(\"cuda\")\n",
    "print(obs)\n",
    "print(np.shape(np.array(ob_buffer)))\n",
    "print('---')\n",
    "obs_cpu = obs.to(\"cpu\").numpy()\n",
    "for i in range(agent_num):\n",
    "    ob_buffer[i].append(obs_cpu[i])\n",
    "print(ob_buffer)\n",
    "ob_buffer[1] = []\n",
    "print(ob_buffer)\n",
    "print('---')\n",
    "for i in range(agent_num):\n",
    "    ob_buffer[i].append(obs_cpu[i])\n",
    "print(ob_buffer[0])\n",
    "vvv = torch.tensor(ob_buffer[0]).to(\"cuda\")\n",
    "print(\"vvv\",vvv)\n",
    "empt = torch.tensor([]).to(\"cuda\")\n",
    "vvvv = torch.cat((empt,vvv),0)\n",
    "print(vvvv)\n",
    "vvvv.size()[0]>0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'Go': 1, 'Attack': 0, 'Free': 0}\n"
     ]
    }
   ],
   "source": [
    "Total = {\"Go\":0,\"Attack\":0,\"Free\":0}\n",
    "\n",
    "Total[\"Go\"] +=1\n",
    "print(Total)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "86e2db13b09bd6be22cb599ea60c1572b9ef36ebeaa27a4c8e961d6df315ac32"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}