diff --git a/backends/mcts_controller.py b/backends/mcts_controller.py
index 2644419b9fe8f9ba9a85969e7e9bd2069d12b123..c5f6996a72c612ef0517cf1371e52ec97982b1da 100644
--- a/backends/mcts_controller.py
+++ b/backends/mcts_controller.py
@@ -106,4 +106,4 @@ class MCTSController(ControllerBase):
         if self.debug:
             print('Btw: %s' % str(p))
         self.mcts.tree.reconstruct(node_after_transition)
-        self.set_current_node(node_after_transition)
\ No newline at end of file
+        self.set_current_node(node_after_transition)
diff --git a/mcts.py b/mcts.py
index 021e37a685fe95280c529463d38fa34269096e38..7efab0bd6cc01ae466aaea913b6082f1be1f95a8 100644
--- a/mcts.py
+++ b/mcts.py
@@ -1,4 +1,5 @@
 from env.simple_intersection import SimpleIntersectionEnv
+from env.simple_intersection.constants import DT
 from options.options_loader import OptionsGraph
 from backends import DDPGLearner, DQNLearner, MCTSLearner
 import numpy as np
@@ -73,10 +74,10 @@ def mcts_evaluation(depth,
                            debug=debug)
 
     # Evaluate
-    print("\nConducting {} trials of {} episodes each".format(
-        nb_trials, nb_episodes))
+    print("\nConducting {} trials of {} episodes each".format(nb_trials, nb_episodes))
+    timeout = 40  # 40 sec. timeout for each episode
     overall_reward_list = []
-    overall_success_accuracy = []
+    overall_success_percent_list = []
     overall_termination_reason_list = {}
     for num_tr in range(nb_trials):
         num_successes = 0
@@ -85,28 +86,31 @@ def mcts_evaluation(depth,
         for num_ep in range(nb_episodes):
             init_obs = options.reset()
             episode_reward = 0
-            first_time = True
             start_time = time.time()
-            while not options.env.is_terminal():
-                if first_time:
-                    first_time = False
-                else:
-                    # print('Stepping through ...')
-                    features, R, terminal, info = options.controller.\
-                        step_current_node(visualize_low_level_steps=visualize)
-                    episode_reward += R
-                    # print('Intermediate Reward: %f (ego x = %f)' %
-                    #       (R, options.env.vehs[0].x))
-                    # print('')
-                    if terminal:
-                        if 'episode_termination_reason' in info:
-                            termination_reason = info['episode_termination_reason']
-                            if termination_reason in trial_termination_reason_counter:
-                                trial_termination_reason_counter[termination_reason] += 1
-                            else:
-                                trial_termination_reason_counter[termination_reason] = 1
+            t = 0
+            while True:
                 if options.controller.can_transition():
                     options.controller.do_transition()
+
+                features, R, terminal, info = options.controller.step_current_node(visualize_low_level_steps=visualize)
+                episode_reward += R
+                t += DT
+                # print('Intermediate Reward: %f (ego x = %f)' %
+                #       (R, options.env.vehs[0].x))
+                # print('')
+
+                if terminal or t > timeout:
+                    if t > timeout:
+                        info['episode_termination_reason'] = 'timeout'
+
+                    if 'episode_termination_reason' in info:
+                        termination_reason = info['episode_termination_reason']
+                        if termination_reason in trial_termination_reason_counter:
+                            trial_termination_reason_counter[termination_reason] += 1
+                        else:
+                            trial_termination_reason_counter[termination_reason] = 1
+                    break
+
             end_time = time.time()
             total_time = int(end_time-start_time)
             if options.env.goal_achieved:
@@ -125,23 +129,24 @@ def mcts_evaluation(depth,
               format(num_tr, np.mean(reward_list), np.std(reward_list), \
                      num_successes, nb_episodes))
         print("Trial {} Termination reason(s):".format(num_tr))
-        for reason, count_list in trial_termination_reason_counter.items():
-            count_list = np.array(count_list)
-            print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list),
-                                                np.std(count_list)))
+        for reason, count in trial_termination_reason_counter.items():
+            print("{}: {}".format(reason, count))
         print("\n")
 
         overall_reward_list += reward_list
-        overall_success_accuracy += [num_successes * 1.0 / nb_episodes]
+        overall_success_percent_list += [num_successes * 100.0 / nb_episodes]
 
     print("===========================")
     print('Overall: Reward = (Avg: {}, Std: {}), Success = (Avg: {}, Std: {})\n'.\
           format(np.mean(overall_reward_list), np.std(overall_reward_list),
-                 np.mean(overall_success_accuracy), np.std(overall_success_accuracy)))
+                 np.mean(overall_success_percent_list), np.std(overall_success_percent_list)))
     print("Termination reason(s):")
     for reason, count_list in overall_termination_reason_list.items():
         count_list = np.array(count_list)
+        while count_list.size != nb_trials:
+            count_list = np.append(count_list, 0)
+
         print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list),
                                             np.std(count_list)))
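
Not part of the patch: below is a minimal, self-contained sketch of the restructured per-episode loop that the mcts.py hunks introduce (transition if possible, step the current node, advance simulated time by DT, and break on terminal or timeout). `DummyController`, the `DT` value, and `run_episode` are illustrative stand-ins rather than names from the repository; only the control flow mirrors the diff.

```python
# Sketch of the evaluation loop after this change; stand-in names, not repo code.
import random

DT = 0.1  # assumed step length in seconds, analogous to simple_intersection's DT


class DummyController:
    """Hypothetical stand-in for options.controller."""

    def __init__(self):
        self.steps = 0

    def can_transition(self):
        # Pretend a high-level transition becomes available every 10 low-level steps.
        return self.steps % 10 == 0

    def do_transition(self):
        pass

    def step_current_node(self, visualize_low_level_steps=False):
        self.steps += 1
        terminal = random.random() < 0.01  # occasionally end the episode
        info = {'episode_termination_reason': 'success'} if terminal else {}
        return None, 1.0, terminal, info  # features, reward, terminal, info


def run_episode(controller, timeout=40.0):
    """Run one episode, mirroring the loop structure in mcts_evaluation."""
    episode_reward, t = 0.0, 0.0
    termination_counter = {}
    while True:
        # Take a high-level transition first if one is available, then step.
        if controller.can_transition():
            controller.do_transition()

        _, reward, terminal, info = controller.step_current_node()
        episode_reward += reward
        t += DT  # simulated time, so the timeout does not depend on wall-clock speed

        if terminal or t > timeout:
            if t > timeout:
                info['episode_termination_reason'] = 'timeout'
            reason = info.get('episode_termination_reason')
            if reason is not None:
                termination_counter[reason] = termination_counter.get(reason, 0) + 1
            break
    return episode_reward, termination_counter


if __name__ == '__main__':
    print(run_episode(DummyController()))
```

Tracking the timeout via accumulated DT (rather than time.time()) keeps episode termination deterministic with respect to the simulation, while the wall-clock timer is retained only for reporting.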