Compare commits

...

1 Commit

Author SHA1 Message Date
10a1663230 Change Tensor to tensor
Fixes the Tensor vs. tensor inconsistency and standardizes tensor usage.
2023-08-08 20:49:23 +09:00
4 changed files with 9 additions and 9 deletions
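For context on the commit: torch.Tensor is the tensor class constructor, so it always produces float32 and treats a bare integer as a size, while torch.tensor copies the given data and infers a dtype unless one is passed explicitly. A minimal sketch of the difference (illustrative only, not part of the diff):

    import numpy as np
    import torch

    # torch.Tensor is the class constructor: a bare int is read as a size,
    # and the result is always float32 regardless of the input data.
    print(torch.Tensor(3).shape)        # torch.Size([3]), uninitialized values
    print(torch.tensor(3))              # tensor(3), a 0-dim int64 scalar

    # With a NumPy array, torch.tensor keeps the source dtype unless told otherwise,
    # which is why the commit passes dtype=torch.float32 to preserve the old behaviour.
    state = np.zeros(4, dtype=np.float64)
    print(torch.Tensor(state).dtype)                        # torch.float32
    print(torch.tensor(state).dtype)                        # torch.float64
    print(torch.tensor(state, dtype=torch.float32).dtype)   # torch.float32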

View File

@@ -112,7 +112,7 @@ if __name__ == "__main__":
         with torch.no_grad():
             # predict actions
             action, dis_logprob, _, con_logprob, _, value = agent.get_actions_value(
-                torch.Tensor(state).to(device)
+                torch.tensor(state, dtype=torch.float32).to(device)
             )
             value = value.flatten()

View File

@@ -61,7 +61,7 @@ class PPOMem:
             thisRewardBF = (np.asarray(thisRewardBF) + (remainTime * self.result_broadcast_ratio)).tolist()
         else:
             print("!!!!!DIDNT GET RESULT REWARD!!!!!!", rewardBF[-1])
-        return torch.Tensor(thisRewardBF).to(self.device)
+        return torch.tensor(thisRewardBF, dtype=torch.float32).to(self.device)

     def save_memories(
         self,
@@ -101,10 +101,10 @@ class PPOMem:
             thisRewardsTensor = self.broad_cast_end_reward(self.rewards_bf[i], remainTime)
             adv, rt = agent.gae(
                 rewards=thisRewardsTensor,
-                dones=torch.Tensor(self.dones_bf[i]).to(self.device),
+                dones=torch.tensor(self.dones_bf[i], dtype=torch.float32).to(self.device),
                 values=torch.tensor(self.values_bf[i]).to(self.device),
                 next_obs=torch.tensor(next_state[i]).to(self.device).unsqueeze(0),
-                next_done=torch.Tensor([next_done[i]]).to(self.device),
+                next_done=torch.tensor([next_done[i]], dtype=torch.float32).to(self.device),
             )
             # send memories to training datasets
             self.obs[roundTargetType] = torch.cat((self.obs[roundTargetType], torch.tensor(np.array(self.ob_bf[i])).to(self.device)), 0)

View File

@@ -34,7 +34,7 @@ BROADCASTREWARD = False
 ANNEAL_LEARNING_RATE = True
 CLIP_VLOSS = True
 NORM_ADV = False
-TRAIN = False
+TRAIN = True
 SAVE_MODEL = False
 WANDB_TACK = False
 LOAD_DIR = None

View File

@@ -275,8 +275,8 @@ class PPOAgent(nn.Module):
         self,
         rewards: torch.Tensor,
         dones: torch.Tensor,
-        values: torch.tensor,
-        next_obs: torch.tensor,
+        values: torch.Tensor,
+        next_obs: torch.Tensor,
         next_done: torch.Tensor,
     ) -> tuple:
         # GAE
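The last hunk only touches type annotations: torch.tensor is a factory function rather than a class, so torch.Tensor is the correct name to use in an annotation. A minimal illustration, using a hypothetical function that is not taken from the repository:

    import torch

    # torch.Tensor names the tensor class, so it is valid in annotations;
    # torch.tensor is a function, so annotating a parameter with it is misleading.
    def scale(values: torch.Tensor, factor: float) -> torch.Tensor:
        return values * factor

    print(scale(torch.tensor([1.0, 2.0]), 2.0))   # tensor([2., 4.])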