1 Commits

Author SHA1 Message Date
Koha9 10a1663230 将Tensor改为tensor
Tensor与tensor的问题,规范化tensor使用。
2023-08-08 20:49:23 +09:00
6 changed files with 198 additions and 362 deletions
+10 -14
View File
@@ -11,7 +11,6 @@ from mlagents_envs.side_channel.side_channel import (
IncomingMessage,
OutgoingMessage,
)
from arguments import set_save_model
class Aimbot(gym.Env):
@@ -177,21 +176,18 @@ class AimbotSideChannel(SideChannel):
"Warning|Message1|Message2|Message3" or
"Error|Message1|Message2|Message3"
"""
this_message_Original = msg.read_string()
this_message = this_message_Original.split("|")
print(this_message)
if this_message[0] == "Warning":
if this_message[1] == "Result":
airecorder.total_rounds[this_message[2]] += 1
if this_message[3] == "Win":
airecorder.win_rounds[this_message[2]] += 1
this_message = msg.read_string()
this_result = this_message.split("|")
print(this_result)
if this_result[0] == "Warning":
if this_result[1] == "Result":
airecorder.total_rounds[this_result[2]] += 1
if this_result[3] == "Win":
airecorder.win_rounds[this_result[2]] += 1
# print(TotalRounds)
# print(WinRounds)
if this_message[1] == "Command":
set_save_model(True)
print("Command: " + this_message_Original)
elif this_message[0] == "Error":
print(this_message_Original)
elif this_result[0] == "Error":
print(this_message)
# # while Message type is Warning
# if(thisResult[0] == "Warning"):
# # while Message1 is result means one game is over
+114 -27
View File
@@ -81,43 +81,130 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import argparse\n",
"import wandb\n",
"import time\n",
"import numpy as np\n",
"import random\n",
"import uuid\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"from AimbotEnv import Aimbot\n",
"from tqdm import tqdm\n",
"from torch.distributions.normal import Normal\n",
"from torch.distributions.categorical import Categorical\n",
"from distutils.util import strtobool\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"from mlagents_envs.environment import UnityEnvironment\n",
"from mlagents_envs.side_channel.side_channel import (\n",
" SideChannel,\n",
" IncomingMessage,\n",
" OutgoingMessage,\n",
")\n",
"from typing import List\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"i = 0\n",
"i = 1\n",
"i = 2\n",
"i = 3\n",
"i = 4\n",
"i = 5\n",
"i = 6\n",
"i = 7\n",
"i = 8\n",
"i = 9\n",
"10\n"
"ename": "AttributeError",
"evalue": "'aaa' object has no attribute 'outa'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[5], line 14\u001b[0m\n\u001b[0;32m 12\u001b[0m asd \u001b[39m=\u001b[39m aaa(outa, outb)\n\u001b[0;32m 13\u001b[0m asd\u001b[39m.\u001b[39mfunc()\n\u001b[1;32m---> 14\u001b[0m \u001b[39mprint\u001b[39m(asd\u001b[39m.\u001b[39;49mouta) \u001b[39m# 输出 100\u001b[39;00m\n",
"\u001b[1;31mAttributeError\u001b[0m: 'aaa' object has no attribute 'outa'"
]
}
],
"source": [
"import threading\n",
"class aaa():\n",
" def __init__(self, a, b):\n",
" self.a = a\n",
" self.b = b\n",
"\n",
"num = 0\n",
" def func(self):\n",
" global outa\n",
" outa = 100\n",
"\n",
"def print_numers():\n",
" global num\n",
" for i in range(10):\n",
" num +=1\n",
" print(\"i = \",i)\n",
"outa = 1\n",
"outb = 2\n",
"asd = aaa(outa, outb)\n",
"asd.func()\n",
"print(asd.outa) # 输出 100"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"usage: ipykernel_launcher.py [-h] [--seed SEED]\n",
"ipykernel_launcher.py: error: unrecognized arguments: --ip=127.0.0.1 --stdin=9003 --control=9001 --hb=9000 --Session.signature_scheme=\"hmac-sha256\" --Session.key=b\"46ef9317-59fb-4ab6-ae4e-6b35744fc423\" --shell=9002 --transport=\"tcp\" --iopub=9004 --f=c:\\Users\\UCUNI\\AppData\\Roaming\\jupyter\\runtime\\kernel-v2-311926K1uko38tdWb.json\n"
]
},
{
"ename": "SystemExit",
"evalue": "2",
"output_type": "error",
"traceback": [
"An exception has occurred, use %tb to see the full traceback.\n",
"\u001b[1;31mSystemExit\u001b[0m\u001b[1;31m:\u001b[0m 2\n"
]
}
],
"source": [
"import argparse\n",
"\n",
"thread = threading.Thread(target=print_numers)\n",
"def parse_args():\n",
" parser = argparse.ArgumentParser()\n",
" parser.add_argument(\"--seed\", type=int, default=11,\n",
" help=\"seed of the experiment\")\n",
" args = parser.parse_args()\n",
" return args\n",
"\n",
"print(num)\n",
"thread.start()\n",
"thread.join()\n",
"print(num)"
"arggg = parse_args()\n",
"print(type(arggg))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0., 0., 0., 0.],\n",
" [0., 0., 0., 0.],\n",
" [0., 0., 0., 0.],\n",
" [0., 0., 0., 0.],\n",
" [0., 0., 0., 0.],\n",
" [0., 0., 0., 0.],\n",
" [0., 0., 0., 0.],\n",
" [0., 0., 0., 0.]])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"np.zeros((8, 4))"
]
}
],
+22 -21
View File
@@ -4,7 +4,6 @@ import random
import uuid
import torch
import atexit
import os
from aimbotEnv import Aimbot
from aimbotEnv import AimbotSideChannel
@@ -13,14 +12,13 @@ from airecorder import WandbRecorder
from aimemory import PPOMem
from aimemory import Targets
from arguments import parse_args
from arguments import set_save_model, is_save_model
import torch.optim as optim
# side channel uuid
SIDE_CHANNEL_UUID = uuid.UUID("8bbfb62a-99b4-457c-879d-b78b69066b5e")
# tensorboard names
GAME_NAME = "Aimbot_Hybrid_Full_MNN_MultiLevel_V2"
GAME_TYPE = "GotoOnly-Level0123-new512Model"
GAME_NAME = "Aimbot_Hybrid_V3"
GAME_TYPE = "Mix_Verification"
if __name__ == "__main__":
args = parse_args()
@@ -49,8 +47,9 @@ if __name__ == "__main__":
# freeze
if args.freeze_viewnet:
# freeze the view network
print("FREEZE VIEW NETWORK is not compatible with Full MNN!")
raise NotImplementedError
for p in agent.viewNetwork.parameters():
p.requires_grad = False
print("VIEW NETWORK FREEZE")
print("Load Agent", args.load_dir)
print(agent.eval())
# optimizer
@@ -59,6 +58,16 @@ if __name__ == "__main__":
run_name = f"{GAME_TYPE}_{args.seed}_{int(time.time())}"
wdb_recorder = WandbRecorder(GAME_NAME, GAME_TYPE, run_name, args)
@atexit.register
def save_model():
# close env
env.close()
if args.save_model:
# save model while exit
save_dir = "../PPO-Model/" + run_name + "_last.pt"
torch.save(agent, save_dir)
print("save model to " + save_dir)
# start the game
total_update_step = args.target_num * args.total_timesteps // args.datasetSize
target_steps = [0 for i in range(args.target_num)]
@@ -214,16 +223,11 @@ if __name__ == "__main__":
)
# print cost time as seconds
print("cost time:", time.time() - start_time)
# New Record! or save model
if ((is_save_model() or TotalRewardMean > best_reward) and args.save_model):
# check saveDir is exist
saveDir = "../PPO-Model/" + run_name + "/"
if not os.path.isdir(saveDir):
os.mkdir(saveDir)
best_reward = TotalRewardMean
torch.save(agent, saveDir + str(TotalRewardMean) + ".pt")
print("Model Saved!")
set_save_model(False)
# New Record!
if TotalRewardMean > best_reward and args.save_model:
best_reward = target_reward_mean
saveDir = "../PPO-Model/" + run_name + "_" + str(TotalRewardMean) + ".pt"
torch.save(agent, saveDir)
else:
# train mode off
mean_reward_list = [] # for WANDB
@@ -246,10 +250,7 @@ if __name__ == "__main__":
TotalRewardMean = np.mean(mean_reward_list)
wdb_recorder.writer.add_scalar("GlobalCharts/TotalRewardMean", TotalRewardMean, total_steps)
saveDir = "../PPO-Model/" + run_name + "/"
if not os.path.isdir(saveDir):
os.mkdir(saveDir)
best_reward = target_reward_mean
torch.save(agent, saveDir + "_last.pt")
saveDir = "../PPO-Model/" + run_name + "_last.pt"
torch.save(agent, saveDir)
env.close()
wdb_recorder.writer.close()
+12 -17
View File
@@ -4,38 +4,41 @@ import uuid
from distutils.util import strtobool
DEFAULT_SEED = 9331
ENV_PATH = "../Build/3.5/Aimbot-ParallelEnv"
ENV_PATH = "../Build/3.1.6/Aimbot-ParallelEnv"
WAND_ENTITY = "koha9"
WORKER_ID = 1
BASE_PORT = 1000
# tensorboard names
GAME_NAME = "Aimbot_Target_Hybrid_PMNN_V3"
GAME_TYPE = "Mix_Verification"
# max round steps per agent is 2500/Decision_period, 25 seconds
TOTAL_STEPS = 3150000
BATCH_SIZE = 512
MAX_TRAINNING_DATASETS = 6000
DECISION_PERIOD = 1
LEARNING_RATE = 1.5e-4
LEARNING_RATE = 6.5e-4
GAMMA = 0.99
GAE_LAMBDA = 0.95
EPOCHS = 3
CLIP_COEF = 0.11
LOSS_COEF = [1.0, 1.0, 1.0, 1.0] # free go attack defence
POLICY_COEF = [0.8, 0.8, 0.8, 0.8]
POLICY_COEF = [1.0, 1.0, 1.0, 1.0]
ENTROPY_COEF = [0.05, 0.05, 0.05, 0.05]
CRITIC_COEF = [0.8, 0.8, 0.8, 0.8]
CRITIC_COEF = [0.5, 0.5, 0.5, 0.5]
TARGET_LEARNING_RATE = 1e-6
FREEZE_VIEW_NETWORK = False
BROADCASTREWARD = False
ANNEAL_LEARNING_RATE = True
CLIP_VLOSS = True
NORM_ADV = False
TRAIN = True
SAVE_MODEL = True
WANDB_TACK = True
SAVE_MODEL = False
WANDB_TACK = False
LOAD_DIR = None
# LOAD_DIR = "../PPO-Model/GotoOnly-Level0123_9331_1696965321/5.1035867.pt"
#LOAD_DIR = "../PPO-Model/PList_Go_LeakyReLU_9331_1677965178_bestGoto/PList_Go_LeakyReLU_9331_1677965178_10.709002.pt"
# Unity Environment Parameters
TARGET_STATE_SIZE = 6
@@ -50,16 +53,6 @@ TARGETNUM= 4
ENV_TIMELIMIT = 30
RESULT_BROADCAST_RATIO = 1/ENV_TIMELIMIT
save_model_this_episode = False
def is_save_model():
global save_model_this_episode
return save_model_this_episode
def set_save_model(save_model:bool):
print("set save model to ",save_model)
global save_model_this_episode
save_model_this_episode = save_model
def parse_args():
# fmt: off
# pytorch and environment parameters
@@ -104,6 +97,8 @@ def parse_args():
help="the number of steps to run in each environment per policy rollout")
parser.add_argument("--result-broadcast-ratio", type=float, default=RESULT_BROADCAST_RATIO,
help="broadcast result when win round is reached,r=result-broadcast-ratio*remainTime")
parser.add_argument("--broadCastEndReward", type=lambda x: bool(strtobool(x)), default=BROADCASTREWARD, nargs="?", const=True,
help="save model or not")
# target_learning_rate
parser.add_argument("--target-lr", type=float, default=TARGET_LEARNING_RATE,
help="target value of downscaling the learning rate")
@@ -1,255 +0,0 @@
import time
import numpy as np
import random
import uuid
import torch
import atexit
import os
from aimbotEnv import Aimbot
from aimbotEnv import AimbotSideChannel
from ppoagent import PPOAgent
from airecorder import WandbRecorder
from aimemory import PPOMem
from aimemory import Targets
from arguments import parse_args
from arguments import set_save_model, is_save_model
import torch.optim as optim
# side channel uuid
SIDE_CHANNEL_UUID = uuid.UUID("8bbfb62a-99b4-457c-879d-b78b69066b5e")
# tensorboard names
GAME_NAME = "Aimbot_Hybrid_Full_MNN_MultiLevel_V2"
GAME_TYPE = "GotoOnly-Level0123-new512Model"
if __name__ == "__main__":
args = parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
best_reward = -1
# Initialize environment agent optimizer
aimbot_side_channel = AimbotSideChannel(SIDE_CHANNEL_UUID)
env = Aimbot(
env_path=args.path,
worker_id=args.workerID,
base_port=args.baseport,
side_channels=[aimbot_side_channel])
if args.load_dir is None:
agent = PPOAgent(
env=env,
this_args=args,
device=device,
).to(device)
else:
agent = torch.load(args.load_dir)
# freeze
if args.freeze_viewnet:
# freeze the view network
print("FREEZE VIEW NETWORK is not compatible with Full MNN!")
raise NotImplementedError
print("Load Agent", args.load_dir)
print(agent.eval())
# optimizer
optimizer = optim.Adam(agent.parameters(), lr=args.lr, eps=1e-5)
# Tensorboard and WandB Recorder
run_name = f"{GAME_TYPE}_{args.seed}_{int(time.time())}"
wdb_recorder = WandbRecorder(GAME_NAME, GAME_TYPE, run_name, args)
# start the game
total_update_step = args.target_num * args.total_timesteps // args.datasetSize
target_steps = [0 for i in range(args.target_num)]
start_time = time.time()
state, _, done = env.reset()
# initialize AI memories
ppo_memories = PPOMem(
args=args,
unity_agent_num=env.unity_agent_num,
device=device,
)
# MAIN LOOP: run agent in environment
for total_steps in range(total_update_step):
# discount learning rate, while step == total_update_step lr will be 0
if args.annealLR:
final_lr_ratio = args.target_lr / args.lr
frac = 1.0 - ((total_steps + 1.0) / total_update_step)
lr_now = frac * args.lr
optimizer.param_groups[0]["lr"] = lr_now
else:
lr_now = args.lr
# episode start show learning rate
print("new episode", total_steps, "learning rate = ", lr_now)
step = 0
training = False
train_queue = []
last_reward = [0. for i in range(env.unity_agent_num)]
# MAIN LOOP: run agent in environment
while True:
# Target Type(state[0][0]) is stay(4),use all zero action
if state[0][0] == 4:
next_state, reward, next_done = env.step(env.all_zero_action)
state, done = next_state, next_done
continue
# On decision point, and Target Type(state[0][0]) is not stay(4) choose action by agent
if step % args.decision_period == 0:
step += 1
# Choose action by agent
with torch.no_grad():
# predict actions
action, dis_logprob, _, con_logprob, _, value = agent.get_actions_value(
torch.tensor(state,dtype=torch.float32).to(device)
)
value = value.flatten()
# variable from GPU to CPU
action_cpu = action.cpu().numpy()
dis_logprob_cpu = dis_logprob.cpu().numpy()
con_logprob_cpu = con_logprob.cpu().numpy()
value_cpu = value.cpu().numpy()
# Environment step
next_state, reward, next_done = env.step(action_cpu)
# save memories
if args.train:
ppo_memories.save_memories(
now_step=step,
agent=agent,
state=state,
action_cpu=action_cpu,
dis_logprob_cpu=dis_logprob_cpu,
con_logprob_cpu=con_logprob_cpu,
reward=reward,
done=done,
value_cpu=value_cpu,
last_reward=last_reward,
next_done=next_done,
next_state=next_state,
)
# check if any training dataset is full and ready to train
for i in range(args.target_num):
if ppo_memories.obs[i].size()[0] >= args.datasetSize:
# start train NN
train_queue.append(i)
if len(train_queue) > 0:
# break while loop and start train
break
# update state
state, done = next_state, next_done
else:
step += 1
# skip this step use last predict action
next_state, reward, next_done = env.step(action_cpu)
# save memories
if args.train:
ppo_memories.save_memories(
now_step=step,
agent=agent,
state=state,
action_cpu=action_cpu,
dis_logprob_cpu=dis_logprob_cpu,
con_logprob_cpu=con_logprob_cpu,
reward=reward,
done=done,
value_cpu=value_cpu,
last_reward=last_reward,
next_done=next_done,
next_state=next_state,
)
# update state
state = next_state
last_reward = reward
if args.train:
# train mode on
mean_reward_list = [] # for WANDB
# loop all training queue
for this_train_ind in train_queue:
# start time
start_time = time.time()
target_steps[this_train_ind] += 1
# train agent
(
v_loss,
dis_pg_loss,
con_pg_loss,
loss,
entropy_loss
) = agent.train_net(
this_train_ind=this_train_ind,
ppo_memories=ppo_memories,
optimizer=optimizer
)
# record mean reward before clear history
print("done")
target_reward_mean = np.mean(ppo_memories.rewards[this_train_ind].to("cpu").detach().numpy().copy())
mean_reward_list.append(target_reward_mean)
targetName = Targets(this_train_ind).name
# clear this target training set buffer
ppo_memories.clear_training_datasets(this_train_ind)
# record rewards for plotting purposes
wdb_recorder.add_target_scalar(
targetName,
this_train_ind,
v_loss,
dis_pg_loss,
con_pg_loss,
loss,
entropy_loss,
target_reward_mean,
target_steps,
)
print(f"episode over Target{targetName} mean reward:", target_reward_mean)
TotalRewardMean = np.mean(mean_reward_list)
wdb_recorder.add_global_scalar(
TotalRewardMean,
optimizer.param_groups[0]["lr"],
total_steps,
)
# print cost time as seconds
print("cost time:", time.time() - start_time)
# New Record! or save model
if ((is_save_model() or TotalRewardMean > best_reward) and args.save_model):
# check saveDir is exist
saveDir = "../PPO-Model/" + run_name + "/"
if not os.path.isdir(saveDir):
os.mkdir(saveDir)
best_reward = TotalRewardMean
torch.save(agent, saveDir + str(TotalRewardMean) + ".pt")
print("Model Saved!")
set_save_model(False)
else:
# train mode off
mean_reward_list = [] # for WANDB
# while not in training mode, clear the buffer
for this_train_ind in train_queue:
target_steps[this_train_ind] += 1
targetName = Targets(this_train_ind).name
target_reward_mean = np.mean(ppo_memories.rewards[this_train_ind].to("cpu").detach().numpy().copy())
mean_reward_list.append(target_reward_mean)
print(target_steps[this_train_ind])
# clear this target training set buffer
ppo_memories.clear_training_datasets(this_train_ind)
# record rewards for plotting purposes
wdb_recorder.writer.add_scalar(f"Target{targetName}/Reward", target_reward_mean,
target_steps[this_train_ind])
wdb_recorder.add_win_ratio(targetName, target_steps[this_train_ind])
print(f"episode over Target{targetName} mean reward:", target_reward_mean)
TotalRewardMean = np.mean(mean_reward_list)
wdb_recorder.writer.add_scalar("GlobalCharts/TotalRewardMean", TotalRewardMean, total_steps)
saveDir = "../PPO-Model/" + run_name + "/"
if not os.path.isdir(saveDir):
os.mkdir(saveDir)
best_reward = target_reward_mean
torch.save(agent, saveDir + "_last.pt")
env.close()
wdb_recorder.writer.close()
+39 -27
View File
@@ -8,8 +8,6 @@ from aimbotEnv import Aimbot
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
firstLayerNum = 512
secondLayerNum = 128
def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
nn.init.orthogonal_(layer.weight, std)
@@ -48,70 +46,84 @@ class PPOAgent(nn.Module):
self.discrete_shape = list(env.unity_discrete_branches)
self.continuous_size = env.unity_continuous_size
self.hidden_networks = nn.ModuleList(
self.view_network = nn.Sequential(layer_init(nn.Linear(self.ray_state_size, 200)), nn.LeakyReLU())
self.target_networks = nn.ModuleList(
[
nn.Sequential(
layer_init(nn.Linear(self.state_size, firstLayerNum)),
nn.LeakyReLU(),
layer_init(nn.Linear(firstLayerNum, secondLayerNum)),
nn.LeakyReLU(),
)
nn.Sequential(layer_init(nn.Linear(self.state_size_without_ray, 100)), nn.LeakyReLU())
for i in range(self.target_num)
]
)
self.middle_networks = nn.ModuleList(
[
nn.Sequential(layer_init(nn.Linear(300, 200)), nn.LeakyReLU())
for i in range(self.target_num)
]
)
self.actor_dis = nn.ModuleList(
[layer_init(nn.Linear(secondLayerNum, self.discrete_size), std=0.5) for i in range(self.target_num)]
[layer_init(nn.Linear(200, self.discrete_size), std=0.5) for i in range(self.target_num)]
)
self.actor_mean = nn.ModuleList(
[layer_init(nn.Linear(secondLayerNum, self.continuous_size), std=0) for i in range(self.target_num)]
[layer_init(nn.Linear(200, self.continuous_size), std=0.5) for i in range(self.target_num)]
)
self.actor_logstd = nn.ParameterList(
[nn.Parameter(torch.zeros(1, self.continuous_size)) for i in range(self.target_num)]
)
) # nn.Parameter(torch.zeros(1, self.continuous_size))
self.critic = nn.ModuleList(
[layer_init(nn.Linear(secondLayerNum, 1), std=0) for i in range(self.target_num)]
[layer_init(nn.Linear(200, 1), std=1) for i in range(self.target_num)]
)
def get_value(self, state: torch.Tensor):
# get critic value
# state.size()[0] is batch_size
target = state[:, 0].to(torch.int32) # int
hidden_output = torch.stack(
[self.hidden_networks[target[i]](state[i]) for i in range(state.size()[0])]
this_state_num = target.size()[0]
view_input = state[:, -self.ray_state_size:] # all ray input
target_input = state[:, : self.state_size_without_ray]
view_layer = self.view_network(view_input)
target_layer = torch.stack(
[self.target_networks[target[i]](target_input[i]) for i in range(this_state_num)]
)
middle_input = torch.cat([view_layer, target_layer], dim=1)
middle_layer = torch.stack(
[self.middle_networks[target[i]](middle_input[i]) for i in range(this_state_num)]
)
criticV = torch.stack(
[self.critic[target[i]](hidden_output[i]) for i in range(state.size()[0])]
)
[self.critic[target[i]](middle_layer[i]) for i in range(this_state_num)]
) # self.critic
return criticV
def get_actions_value(self, state: torch.Tensor, actions=None):
# get actions and value
target = state[:, 0].to(torch.int32) # int
hidden_output = torch.stack(
[self.hidden_networks[target[i]](state[i]) for i in range(target.size()[0])]
this_state_num = target.size()[0]
view_input = state[:, -self.ray_state_size:] # all ray input
target_input = state[:, : self.state_size_without_ray]
view_layer = self.view_network(view_input)
target_layer = torch.stack(
[self.target_networks[target[i]](target_input[i]) for i in range(this_state_num)]
)
middle_input = torch.cat([view_layer, target_layer], dim=1)
middle_layer = torch.stack(
[self.middle_networks[target[i]](middle_input[i]) for i in range(this_state_num)]
)
# discrete
# 递归targets的数量,既agent数来实现根据target不同来选用对应的输出网络计算输出
dis_logits = torch.stack(
[self.actor_dis[target[i]](hidden_output[i]) for i in range(target.size()[0])]
[self.actor_dis[target[i]](middle_layer[i]) for i in range(this_state_num)]
)
split_logits = torch.split(dis_logits, self.discrete_shape, dim=1)
multi_categoricals = [Categorical(logits=thisLogits) for thisLogits in split_logits]
# continuous
actions_mean = torch.stack(
[self.actor_mean[target[i]](hidden_output[i]) for i in range(target.size()[0])]
[self.actor_mean[target[i]](middle_layer[i]) for i in range(this_state_num)]
) # self.actor_mean(hidden)
action_logstd = torch.stack(
[torch.squeeze(self.actor_logstd[target[i]], 0) for i in range(target.size()[0])]
[torch.squeeze(self.actor_logstd[target[i]], 0) for i in range(this_state_num)]
)
# print(action_logstd)
action_std = torch.exp(action_logstd) # torch.exp(action_logstd)
con_probs = Normal(actions_mean, action_std)
# critic
criticV = torch.stack(
[self.critic[target[i]](hidden_output[i]) for i in range(target.size()[0])]
[self.critic[target[i]](middle_layer[i]) for i in range(this_state_num)]
) # self.critic
if actions is None: