MultiThread PPO First Commit

This commit is contained in:
2023-11-23 15:25:34 +09:00
parent 3bc5c30fd3
commit 2ea8a5f104
5 changed files with 296 additions and 180 deletions
+28 -169
View File
@@ -81,184 +81,43 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import argparse\n",
"import wandb\n",
"import time\n",
"import numpy as np\n",
"import random\n",
"import uuid\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"from AimbotEnv import Aimbot\n",
"from tqdm import tqdm\n",
"from torch.distributions.normal import Normal\n",
"from torch.distributions.categorical import Categorical\n",
"from distutils.util import strtobool\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"from mlagents_envs.environment import UnityEnvironment\n",
"from mlagents_envs.side_channel.side_channel import (\n",
" SideChannel,\n",
" IncomingMessage,\n",
" OutgoingMessage,\n",
")\n",
"from typing import List\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"ename": "AttributeError",
"evalue": "'aaa' object has no attribute 'outa'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[5], line 14\u001b[0m\n\u001b[0;32m 12\u001b[0m asd \u001b[39m=\u001b[39m aaa(outa, outb)\n\u001b[0;32m 13\u001b[0m asd\u001b[39m.\u001b[39mfunc()\n\u001b[1;32m---> 14\u001b[0m \u001b[39mprint\u001b[39m(asd\u001b[39m.\u001b[39;49mouta) \u001b[39m# 输出 100\u001b[39;00m\n",
"\u001b[1;31mAttributeError\u001b[0m: 'aaa' object has no attribute 'outa'"
]
}
],
"source": [
"class aaa():\n",
" def __init__(self, a, b):\n",
" self.a = a\n",
" self.b = b\n",
"\n",
" def func(self):\n",
" global outa\n",
" outa = 100\n",
"\n",
"outa = 1\n",
"outb = 2\n",
"asd = aaa(outa, outb)\n",
"asd.func()\n",
"print(asd.outa) # 输出 100"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"usage: ipykernel_launcher.py [-h] [--seed SEED]\n",
"ipykernel_launcher.py: error: unrecognized arguments: --ip=127.0.0.1 --stdin=9003 --control=9001 --hb=9000 --Session.signature_scheme=\"hmac-sha256\" --Session.key=b\"46ef9317-59fb-4ab6-ae4e-6b35744fc423\" --shell=9002 --transport=\"tcp\" --iopub=9004 --f=c:\\Users\\UCUNI\\AppData\\Roaming\\jupyter\\runtime\\kernel-v2-311926K1uko38tdWb.json\n"
]
},
{
"ename": "SystemExit",
"evalue": "2",
"output_type": "error",
"traceback": [
"An exception has occurred, use %tb to see the full traceback.\n",
"\u001b[1;31mSystemExit\u001b[0m\u001b[1;31m:\u001b[0m 2\n"
]
}
],
"source": [
"import argparse\n",
"\n",
"def parse_args():\n",
" parser = argparse.ArgumentParser()\n",
" parser.add_argument(\"--seed\", type=int, default=11,\n",
" help=\"seed of the experiment\")\n",
" args = parser.parse_args()\n",
" return args\n",
"\n",
"arggg = parse_args()\n",
"print(type(arggg))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y=\"a;b;c\"\n",
"len(y.split(\";\"))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2]\n"
"0\n",
"i = 0\n",
"i = 1\n",
"i = 2\n",
"i = 3\n",
"i = 4\n",
"i = 5\n",
"i = 6\n",
"i = 7\n",
"i = 8\n",
"i = 9\n",
"10\n"
]
}
],
"source": [
"a = np.array([1,2,3,4])\n",
"print(a[[False,True,False,False]])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{1, 2, 3, 4}"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = {1,2,3}\n",
"a.add(4)\n",
"a"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([3, 4])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = np.array([[1,3],[2,4]])\n",
"a.max(axis=1)\n"
"import threading\n",
"\n",
"num = 0\n",
"\n",
"def print_numers():\n",
" global num\n",
" for i in range(10):\n",
" num +=1\n",
" print(\"i = \",i)\n",
"\n",
"thread = threading.Thread(target=print_numers)\n",
"\n",
"print(num)\n",
"thread.start()\n",
"thread.join()\n",
"print(num)"
]
}
],