Skip to content
Snippets Groups Projects
Commit a741a9c5 authored by Jaeyoung Lee's avatar Jaeyoung Lee
Browse files

Improve MCTS.

- Fixed a bug of starting with start_node;
- Fixed the miscalculation of mean and std;
- Added timeout to prevent infinite-loop.
parent 67b0a673
No related branches found
No related tags found
1 merge request!5Bug fix and improve MCTS and others
......@@ -106,4 +106,4 @@ class MCTSController(ControllerBase):
if self.debug:
print('Btw: %s' % str(p))
self.mcts.tree.reconstruct(node_after_transition)
self.set_current_node(node_after_transition)
\ No newline at end of file
self.set_current_node(node_after_transition)
from env.simple_intersection import SimpleIntersectionEnv
from env.simple_intersection.constants import DT
from options.options_loader import OptionsGraph
from backends import DDPGLearner, DQNLearner, MCTSLearner
import numpy as np
......@@ -73,10 +74,10 @@ def mcts_evaluation(depth,
debug=debug)
# Evaluate
print("\nConducting {} trials of {} episodes each".format(
nb_trials, nb_episodes))
print("\nConducting {} trials of {} episodes each".format(nb_trials, nb_episodes))
timeout = 40 # 40 sec. timeout for each episode
overall_reward_list = []
overall_success_accuracy = []
overall_success_percent_list = []
overall_termination_reason_list = {}
for num_tr in range(nb_trials):
num_successes = 0
......@@ -85,28 +86,31 @@ def mcts_evaluation(depth,
for num_ep in range(nb_episodes):
init_obs = options.reset()
episode_reward = 0
first_time = True
start_time = time.time()
while not options.env.is_terminal():
if first_time:
first_time = False
else:
# print('Stepping through ...')
features, R, terminal, info = options.controller.\
step_current_node(visualize_low_level_steps=visualize)
episode_reward += R
# print('Intermediate Reward: %f (ego x = %f)' %
# (R, options.env.vehs[0].x))
# print('')
if terminal:
if 'episode_termination_reason' in info:
termination_reason = info['episode_termination_reason']
if termination_reason in trial_termination_reason_counter:
trial_termination_reason_counter[termination_reason] += 1
else:
trial_termination_reason_counter[termination_reason] = 1
t = 0
while True:
if options.controller.can_transition():
options.controller.do_transition()
features, R, terminal, info = options.controller.step_current_node(visualize_low_level_steps=visualize)
episode_reward += R
t += DT
# print('Intermediate Reward: %f (ego x = %f)' %
# (R, options.env.vehs[0].x))
# print('')
if terminal or t > timeout:
if t > timeout:
info['episode_termination_reason'] = 'timeout'
if 'episode_termination_reason' in info:
termination_reason = info['episode_termination_reason']
if termination_reason in trial_termination_reason_counter:
trial_termination_reason_counter[termination_reason] += 1
else:
trial_termination_reason_counter[termination_reason] = 1
break
end_time = time.time()
total_time = int(end_time-start_time)
if options.env.goal_achieved:
......@@ -125,23 +129,24 @@ def mcts_evaluation(depth,
format(num_tr, np.mean(reward_list), np.std(reward_list), \
num_successes, nb_episodes))
print("Trial {} Termination reason(s):".format(num_tr))
for reason, count_list in trial_termination_reason_counter.items():
count_list = np.array(count_list)
print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list),
np.std(count_list)))
for reason, count in trial_termination_reason_counter.items():
print("{}: {}".format(reason, count))
print("\n")
overall_reward_list += reward_list
overall_success_accuracy += [num_successes * 1.0 / nb_episodes]
overall_success_percent_list += [num_successes * 100.0 / nb_episodes]
print("===========================")
print('Overall: Reward = (Avg: {}, Std: {}), Success = (Avg: {}, Std: {})\n'.\
format(np.mean(overall_reward_list), np.std(overall_reward_list),
np.mean(overall_success_accuracy), np.std(overall_success_accuracy)))
np.mean(overall_success_percent_list), np.std(overall_success_percent_list)))
print("Termination reason(s):")
for reason, count_list in overall_termination_reason_list.items():
count_list = np.array(count_list)
while count_list.size != nb_trials:
count_list = np.append(count_list, 0)
print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list),
np.std(count_list)))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment