Skip to content
Snippets Groups Projects
Commit 2e933764 authored by Jaeyoung Lee's avatar Jaeyoung Lee
Browse files

Output format change in mcts.py and misc.

parent 52dcb599
No related branches found
No related tags found
1 merge request!5Bug fix and improve MCTS and others
......@@ -84,7 +84,7 @@ def mcts_evaluation(depth,
reward_list = []
trial_termination_reason_counter = {}
for num_ep in range(nb_episodes):
init_obs = options.reset()
options.reset()
episode_reward = 0
start_time = time.time()
t = 0
......@@ -115,7 +115,7 @@ def mcts_evaluation(depth,
total_time = int(end_time-start_time)
if options.env.goal_achieved:
num_successes += 1
print('Episode {}: Reward = {} ({})'.format(num_ep, episode_reward,
print('Episode {}: Reward = {:.2f} ({})'.format(num_ep, episode_reward,
datetime.timedelta(seconds=total_time)))
reward_list += [episode_reward]
......@@ -125,9 +125,8 @@ def mcts_evaluation(depth,
else:
overall_termination_reason_list[reason] = [count]
print("\nTrial {}: Reward = (Avg: {}, Std: {}), Successes: {}/{}".\
format(num_tr, np.mean(reward_list), np.std(reward_list), \
num_successes, nb_episodes))
print("\nTrial {}: Reward = (Avg: {:.2f}, Std: {:.2f}), Successes: {}/{}".\
format(num_tr, np.mean(reward_list), np.std(reward_list), num_successes, nb_episodes))
print("Trial {} Termination reason(s):".format(num_tr))
for reason, count in trial_termination_reason_counter.items():
print("{}: {}".format(reason, count))
......@@ -137,7 +136,7 @@ def mcts_evaluation(depth,
overall_success_percent_list += [num_successes * 100.0 / nb_episodes]
print("===========================")
print('Overall: Reward = (Avg: {}, Std: {}), Success = (Avg: {}, Std: {})\n'.\
print('Overall: Reward = (Avg: {:.2f}, Std: {:.2f}), Success = (Avg: {:.2f}, Std: {:.2f})\n'.\
format(np.mean(overall_reward_list), np.std(overall_reward_list),
np.mean(overall_success_percent_list), np.std(overall_success_percent_list)))
......@@ -147,8 +146,8 @@ def mcts_evaluation(depth,
while count_list.size != nb_trials:
count_list = np.append(count_list, 0)
print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list),
np.std(count_list)))
print("{}: Avg: {:.2f}, Std: {:.2f}".format(reason, np.mean(count_list), np.std(count_list)))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......
......@@ -103,8 +103,7 @@ class Image(Shape):
group), "car" (represents a car)
"""
self.image = pyglet.image.load('env/simple_intersection/graphics/' +
image_url)
self.image = pyglet.image.load('worlds/simple_intersection/graphics/' + image_url)
if anchor is None:
self.image.anchor_x = self.image.width // 2
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment