diff --git a/mcts.py b/mcts.py index 8fd842edc0f2ba6c4f9093ae539d57ee0b11479f..935fc89315443a4048564714bcd26570804c3eae 100644 --- a/mcts.py +++ b/mcts.py @@ -84,7 +84,7 @@ def mcts_evaluation(depth, reward_list = [] trial_termination_reason_counter = {} for num_ep in range(nb_episodes): - init_obs = options.reset() + options.reset() episode_reward = 0 start_time = time.time() t = 0 @@ -115,7 +115,7 @@ def mcts_evaluation(depth, total_time = int(end_time-start_time) if options.env.goal_achieved: num_successes += 1 - print('Episode {}: Reward = {} ({})'.format(num_ep, episode_reward, + print('Episode {}: Reward = {:.2f} ({})'.format(num_ep, episode_reward, datetime.timedelta(seconds=total_time))) reward_list += [episode_reward] @@ -125,9 +125,8 @@ def mcts_evaluation(depth, else: overall_termination_reason_list[reason] = [count] - print("\nTrial {}: Reward = (Avg: {}, Std: {}), Successes: {}/{}".\ - format(num_tr, np.mean(reward_list), np.std(reward_list), \ - num_successes, nb_episodes)) + print("\nTrial {}: Reward = (Avg: {:.2f}, Std: {:.2f}), Successes: {}/{}".\ + format(num_tr, np.mean(reward_list), np.std(reward_list), num_successes, nb_episodes)) print("Trial {} Termination reason(s):".format(num_tr)) for reason, count in trial_termination_reason_counter.items(): print("{}: {}".format(reason, count)) @@ -137,7 +136,7 @@ def mcts_evaluation(depth, overall_success_percent_list += [num_successes * 100.0 / nb_episodes] print("===========================") - print('Overall: Reward = (Avg: {}, Std: {}), Success = (Avg: {}, Std: {})\n'.\ + print('Overall: Reward = (Avg: {:.2f}, Std: {:.2f}), Success = (Avg: {:.2f}, Std: {:.2f})\n'.\ format(np.mean(overall_reward_list), np.std(overall_reward_list), np.mean(overall_success_percent_list), np.std(overall_success_percent_list))) @@ -147,8 +146,8 @@ def mcts_evaluation(depth, while count_list.size != nb_trials: count_list = np.append(count_list, 0) - print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list), - np.std(count_list))) + print("{}: Avg: {:.2f}, Std: {:.2f}".format(reason, np.mean(count_list), np.std(count_list))) + if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/worlds/simple_intersection/shapes.py b/worlds/simple_intersection/shapes.py index d814290c023417716385c3ba45b1cd340d01ee88..f3fd3f8f5cd999fdc8e90ce25c455424d2537b0b 100644 --- a/worlds/simple_intersection/shapes.py +++ b/worlds/simple_intersection/shapes.py @@ -103,8 +103,7 @@ class Image(Shape): group), "car" (represents a car) """ - self.image = pyglet.image.load('env/simple_intersection/graphics/' + - image_url) + self.image = pyglet.image.load('worlds/simple_intersection/graphics/' + image_url) if anchor is None: self.image.anchor_x = self.image.width // 2