Fix wrong remainTime

Fix wrong remainTime — what a careless mistake...
Also remove the duplicated WandB writer call.
This commit is contained in:
Koha9 2022-12-04 09:20:05 +09:00
parent ad9817e7a4
commit 1787872e82

View File

@ -65,6 +65,7 @@ class Targets(Enum):
Attack = 2 Attack = 2
Defence = 3 Defence = 3
Num = 4 Num = 4
STATE_REMAINTIME_POSITION = 6
BASE_WINREWARD = 999 BASE_WINREWARD = 999
BASE_LOSEREWARD = -999 BASE_LOSEREWARD = -999
TARGETNUM= 4 TARGETNUM= 4
@ -417,7 +418,7 @@ if __name__ == "__main__":
value_cpu = value.cpu().numpy() value_cpu = value.cpu().numpy()
# Environment step # Environment step
next_state, reward, next_done = env.step(action_cpu) next_state, reward, next_done = env.step(action_cpu)
remainTime = state[i,STATE_REMAINTIME_POSITION]
# save memories # save memories
for i in range(env.unity_agent_num): for i in range(env.unity_agent_num):
# save memories to buffers # save memories to buffers
@ -433,7 +434,7 @@ if __name__ == "__main__":
# compute advantage and discounted reward # compute advantage and discounted reward
#print(i,"over") #print(i,"over")
roundTargetType = int(state[i,0]) roundTargetType = int(state[i,0])
thisRewardsTensor = broadCastEndReward(rewards_bf[i],roundTargetType) thisRewardsTensor = broadCastEndReward(rewards_bf[i],remainTime)
adv, rt = GAE( adv, rt = GAE(
agent, agent,
args, args,
@ -646,7 +647,6 @@ if __name__ == "__main__":
# record rewards for plotting purposes # record rewards for plotting purposes
writer.add_scalar(f"Target{targetName}/value_loss", v_loss.item(), target_steps[thisT]) writer.add_scalar(f"Target{targetName}/value_loss", v_loss.item(), target_steps[thisT])
writer.add_scalar(f"Target{targetName}/value_loss", v_loss.item(), target_steps[thisT])
writer.add_scalar(f"Target{targetName}/dis_policy_loss", dis_pg_loss.item(), target_steps[thisT]) writer.add_scalar(f"Target{targetName}/dis_policy_loss", dis_pg_loss.item(), target_steps[thisT])
writer.add_scalar(f"Target{targetName}/con_policy_loss", con_pg_loss.item(), target_steps[thisT]) writer.add_scalar(f"Target{targetName}/con_policy_loss", con_pg_loss.item(), target_steps[thisT])
writer.add_scalar(f"Target{targetName}/total_loss", loss.item(), target_steps[thisT]) writer.add_scalar(f"Target{targetName}/total_loss", loss.item(), target_steps[thisT])