Commit 6b50d1d1 authored by Jae Young Lee

Merge branch 'bug-fix-and-improve-MCTS-and-others' into 'master'

Fix bugs and improve MCTS, among other changes

See merge request !5
parents d785363e 0e92b9ab
@@ -106,4 +106,4 @@ class MCTSController(ControllerBase):
         if self.debug:
             print('Btw: %s' % str(p))
         self.mcts.tree.reconstruct(node_after_transition)
-        self.set_current_node(node_after_transition)
\ No newline at end of file
+        self.set_current_node(node_after_transition)
-from env.simple_intersection import SimpleIntersectionEnv
-from env.simple_intersection.constants import *
+from worlds.simple_intersection import SimpleIntersectionEnv
+from worlds.simple_intersection.constants import *
 from options.options_loader import OptionsGraph
 from backends.kerasrl_learner import DQNLearner
 import os
@@ -129,6 +129,9 @@ def evaluate_high_level_policy(nb_episodes_for_test=100,
     print("Termination reason(s):")
     for reason, count_list in termination_reason_list.items():
         count_list = np.array(count_list)
+        while count_list.size != nb_trials:
+            count_list = np.append(count_list, 0)
         print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list),
                                             np.std(count_list)))
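The zero-padding added above changes the reported statistics in a subtle way: a termination reason that never occurred in some trials would otherwise be averaged over fewer than nb_trials values, inflating its mean. A minimal sketch of the effect (the numbers are illustrative):

```python
import numpy as np

nb_trials = 5
count_list = np.array([3, 1])  # reason occurred in only 2 of 5 trials

# Unpadded: statistics over 2 values  -> mean 2.0
# Padded:   statistics over all 5 trials -> mean 0.8
while count_list.size != nb_trials:
    count_list = np.append(count_list, 0)

print(np.mean(count_list), np.std(count_list))  # 0.8 1.166...
```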
...
This diff is collapsed.
-from env.simple_intersection import SimpleIntersectionEnv
-from env.simple_intersection.constants import *
+from worlds.simple_intersection import SimpleIntersectionEnv
+from worlds.simple_intersection.constants import *
 from options.options_loader import OptionsGraph
 from backends.kerasrl_learner import DDPGLearner
 from rl.callbacks import Callback
...
-from env.simple_intersection import SimpleIntersectionEnv
+from worlds.simple_intersection import SimpleIntersectionEnv
+from worlds.simple_intersection.constants import DT
 from options.options_loader import OptionsGraph
 from backends import DDPGLearner, DQNLearner, MCTSLearner
 import numpy as np
@@ -73,45 +74,47 @@ def mcts_evaluation(depth,
         debug=debug)
     # Evaluate
-    print("\nConducting {} trials of {} episodes each".format(
-        nb_trials, nb_episodes))
+    print("\nConducting {} trials of {} episodes each".format(nb_trials, nb_episodes))
+    timeout = 40  # 40 sec. timeout for each episode
     overall_reward_list = []
-    overall_success_accuracy = []
+    overall_success_percent_list = []
     overall_termination_reason_list = {}
     for num_tr in range(nb_trials):
         num_successes = 0
         reward_list = []
         trial_termination_reason_counter = {}
         for num_ep in range(nb_episodes):
-            init_obs = options.reset()
+            options.reset()
             episode_reward = 0
-            first_time = True
             start_time = time.time()
-            while not options.env.is_terminal():
-                if first_time:
-                    first_time = False
-                else:
-                    # print('Stepping through ...')
-                    features, R, terminal, info = options.controller.\
-                        step_current_node(visualize_low_level_steps=visualize)
-                    episode_reward += R
-                    # print('Intermediate Reward: %f (ego x = %f)' %
-                    #       (R, options.env.vehs[0].x))
-                    # print('')
-                    if terminal:
-                        if 'episode_termination_reason' in info:
-                            termination_reason = info['episode_termination_reason']
-                            if termination_reason in trial_termination_reason_counter:
-                                trial_termination_reason_counter[termination_reason] += 1
-                            else:
-                                trial_termination_reason_counter[termination_reason] = 1
-                if options.controller.can_transition():
-                    options.controller.do_transition()
+            t = 0
+            while True:
+                options.controller.do_transition()
+                features, R, terminal, info = options.controller.step_current_node(visualize_low_level_steps=visualize)
+                episode_reward += R
+                t += DT
+                # print('Intermediate Reward: %f (ego x = %f)' %
+                #       (R, options.env.vehs[0].x))
+                # print('')
+                if terminal or t > timeout:
+                    if t > timeout:
+                        info['episode_termination_reason'] = 'timeout'
+                    if 'episode_termination_reason' in info:
+                        termination_reason = info['episode_termination_reason']
+                        if termination_reason in trial_termination_reason_counter:
+                            trial_termination_reason_counter[termination_reason] += 1
+                        else:
+                            trial_termination_reason_counter[termination_reason] = 1
+                    break
             end_time = time.time()
             total_time = int(end_time - start_time)
             if options.env.goal_achieved:
                 num_successes += 1
-            print('Episode {}: Reward = {} ({})'.format(num_ep, episode_reward,
-                                                        datetime.timedelta(seconds=total_time)))
+            print('Episode {}: Reward = {:.2f} ({})'.format(num_ep, episode_reward,
                                                             datetime.timedelta(seconds=total_time)))
             reward_list += [episode_reward]
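The rewritten inner loop above replaces the is_terminal()/first_time pattern with a single transition-then-step cycle and a timeout measured in simulated time: t advances by DT per step, so the 40-second budget is independent of how fast the simulation runs on the wall clock. A stripped-down sketch of that control flow, with step() standing in for options.controller.step_current_node() (hypothetical helper, not part of the diff):

```python
def run_episode(step, DT, timeout=40.0):
    """Run one episode; step() returns (reward, terminal, info)."""
    episode_reward, t = 0.0, 0.0
    while True:
        reward, terminal, info = step()
        episode_reward += reward
        t += DT  # simulated elapsed time, not wall-clock time
        if terminal or t > timeout:
            if t > timeout:
                # overlong episodes are counted under their own reason
                info['episode_termination_reason'] = 'timeout'
            return episode_reward, info
```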
@@ -121,29 +124,29 @@ def mcts_evaluation(depth,
             else:
                 overall_termination_reason_list[reason] = [count]
-        print("\nTrial {}: Reward = (Avg: {}, Std: {}), Successes: {}/{}".\
-              format(num_tr, np.mean(reward_list), np.std(reward_list), \
-                     num_successes, nb_episodes))
+        print("\nTrial {}: Reward = (Avg: {:.2f}, Std: {:.2f}), Successes: {}/{}".\
+              format(num_tr, np.mean(reward_list), np.std(reward_list), num_successes, nb_episodes))
         print("Trial {} Termination reason(s):".format(num_tr))
-        for reason, count_list in trial_termination_reason_counter.items():
-            count_list = np.array(count_list)
-            print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list),
-                                                np.std(count_list)))
+        for reason, count in trial_termination_reason_counter.items():
+            print("{}: {}".format(reason, count))
         print("\n")
         overall_reward_list += reward_list
-        overall_success_accuracy += [num_successes * 1.0 / nb_episodes]
+        overall_success_percent_list += [num_successes * 100.0 / nb_episodes]
     print("===========================")
-    print('Overall: Reward = (Avg: {}, Std: {}), Success = (Avg: {}, Std: {})\n'.\
-          format(np.mean(overall_reward_list), np.std(overall_reward_list),
-                 np.mean(overall_success_accuracy), np.std(overall_success_accuracy)))
+    print('Overall: Reward = (Avg: {:.2f}, Std: {:.2f}), Success = (Avg: {:.2f}, Std: {:.2f})\n'.\
+          format(np.mean(overall_reward_list), np.std(overall_reward_list),
+                 np.mean(overall_success_percent_list), np.std(overall_success_percent_list)))
     print("Termination reason(s):")
     for reason, count_list in overall_termination_reason_list.items():
         count_list = np.array(count_list)
-        print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list),
-                                            np.std(count_list)))
+        while count_list.size != nb_trials:
+            count_list = np.append(count_list, 0)
+        print("{}: Avg: {:.2f}, Std: {:.2f}".format(reason, np.mean(count_list), np.std(count_list)))
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
...
 import numpy as np
 import gym
-import env.simple_intersection.road_geokinemetry as rd
-from env.simple_intersection.constants import DT, MAX_ACCELERATION, MAX_STEERING_ANGLE_RATE, MAX_STEERING_ANGLE
-from env import EpisodicEnvBase
+import worlds.simple_intersection.road_geokinemetry as rd
+from worlds.simple_intersection.constants import DT, MAX_ACCELERATION, MAX_STEERING_ANGLE_RATE, MAX_STEERING_ANGLE
+from worlds import EpisodicEnvBase
 class ManeuverBase(EpisodicEnvBase):
...
 from .maneuver_base import ManeuverBase
-from env.simple_intersection.constants import *
-import env.simple_intersection.road_geokinemetry as rd
-from env.simple_intersection.features import extract_ego_features, extract_other_veh_features
+from worlds.simple_intersection.constants import *
+import worlds.simple_intersection.road_geokinemetry as rd
+from worlds.simple_intersection.features import extract_ego_features, extract_other_veh_features
 from verifier.simple_intersection import LTLProperty
 import numpy as np
@@ -94,7 +94,7 @@ class Halt(ManeuverBase):
         # (currently, this functionality is implemented by "not self._enable_low_level_training_properties")
         self._LTL_preconditions.append(
             LTLProperty(
-                "G ( (veh_ahead and before_but_close_to_stop_region) U highest_priority )",
+                "G ( (veh_ahead and close_to_stop_region) U highest_priority )",
                 None, not self._enable_low_level_training_properties))
         self._LTL_preconditions.append(
@@ -166,10 +166,10 @@ class Stop(ManeuverBase):
             LTLProperty("G ( not has_stopped_in_stop_region )",
                         self._penalty(self._reward_in_goal), not self._enable_low_level_training_properties))
-        # before_intersection rather than "before_but_close_to_stop_region or in_stop_region"?
+        # before_intersection rather than "close_to_stop_region or in_stop_region"?
         self._LTL_preconditions.append(
             LTLProperty(
-                "G ( (before_but_close_to_stop_region or in_stop_region) U has_stopped_in_stop_region )",
+                "G ( (close_to_stop_region or in_stop_region) U has_stopped_in_stop_region )",
                 self._penalty_in_violation))
         self._LTL_preconditions.append(
@@ -485,7 +485,7 @@ class Follow(ManeuverBase):
     def _init_LTL_preconditions(self):
         self._LTL_preconditions.append(
-            LTLProperty("G ( veh_ahead U (in_stop_region or before_but_close_to_stop_region ) )", self._penalty_for_out_of_range))
+            LTLProperty("G ( veh_ahead U (in_stop_region or close_to_stop_region ) )", self._penalty_for_out_of_range))
         self._LTL_preconditions.append(
             LTLProperty(
...
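For readers unfamiliar with the property syntax used in these preconditions: G is "globally" and U is "until", so "G (p U q)" demands that from every step of the episode, p keeps holding until q becomes true. A finite-trace sketch of those semantics (this is not the verifier's implementation, which lives in verifier.simple_intersection):

```python
def until_holds(p, q, i):
    """Strong until at step i: q eventually holds, and p holds at every
    step from i up to (but not including) the first step where q holds."""
    for j in range(i, len(q)):
        if q[j]:
            return all(p[k] for k in range(i, j))
    return False  # q never holds, so p U q fails

def globally_until(p, q):
    """G (p U q) on a finite trace: p U q must hold starting at every step."""
    return all(until_holds(p, q, i) for i in range(len(p)))

# e.g. "G ( veh_ahead U close_to_stop_region )" over a 4-step trace:
veh_ahead = [True, True, True, False]
close_to_stop_region = [False, False, True, True]
print(globally_until(veh_ahead, close_to_stop_region))  # True
```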
-from env.simple_intersection import SimpleIntersectionEnv
-from env.simple_intersection.constants import *
+from worlds.simple_intersection import SimpleIntersectionEnv
+from worlds.simple_intersection.constants import *
 from options.options_loader import OptionsGraph
 from backends.baselines_learner import PPO2Agent
...
@@ -4,7 +4,7 @@
 # * in_stop_region: True if the veh. is in stop region
 # * has_entered_stop_region: True if the veh. has entered or passed the stop
 #   region
-# * before_but_close_to_stop_region: True if the veh. is before but close
+# * close_to_stop_region: True if the veh. is before but close
 #   to the stop region
 # * stopped_now: True if the veh. is now stopped
 # * has_stopped_in_stop_region: True if the veh. has ever stopped in the
@@ -26,7 +26,7 @@
 AP_dict_simple_intersection = {
     'in_stop_region': 0,
     'has_entered_stop_region': 1,
-    'before_but_close_to_stop_region': 2,
+    'close_to_stop_region': 2,
     'stopped_now': 3,
     'has_stopped_in_stop_region': 4,
     'in_intersection': 5,
...
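The dictionary above assigns each atomic proposition (AP) a fixed index, which suggests AP valuations are packed into a bit vector before being handed to the LTL monitors. A hedged sketch of such an encoding; the actual packing code is not shown in this diff, so treat this as an assumption about how the indices are used:

```python
AP_dict_simple_intersection = {
    'in_stop_region': 0,
    'has_entered_stop_region': 1,
    'close_to_stop_region': 2,
    'stopped_now': 3,
    'has_stopped_in_stop_region': 4,
    'in_intersection': 5,
}

def encode_APs(APs, ap_dict=AP_dict_simple_intersection):
    """Pack a {name: bool} AP valuation into an int, one bit per AP index."""
    word = 0
    for name, index in ap_dict.items():
        if APs.get(name, False):
            word |= 1 << index
    return word

# e.g. a vehicle currently stopped inside the stop region:
print(bin(encode_APs({'in_stop_region': True, 'stopped_now': True})))  # 0b1001
```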
@@ -103,8 +103,7 @@ class Image(Shape):
             group), "car" (represents a car)
         """
-        self.image = pyglet.image.load('env/simple_intersection/graphics/' +
-                                       image_url)
+        self.image = pyglet.image.load('worlds/simple_intersection/graphics/' + image_url)
         if anchor is None:
             self.image.anchor_x = self.image.width // 2
...
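For context on the anchor lines: a pyglet image's anchor_x/anchor_y define the reference point that positioning and rotation use, and they default to the bottom-left corner, so centering the anchor makes car sprites rotate about their middle. A standalone sketch (the file name is illustrative, not from the diff):

```python
import pyglet

# Load a sprite image and center its anchor so rotation happens about
# the image's midpoint rather than its bottom-left corner.
image = pyglet.image.load('worlds/simple_intersection/graphics/car_agent.png')
image.anchor_x = image.width // 2
image.anchor_y = image.height // 2
```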
@@ -2,7 +2,7 @@ import pyglet
 import gym
 from copy import deepcopy
-from env import RoadEnv, EpisodicEnvBase
+from worlds import RoadEnv, EpisodicEnvBase
 from .vehicles import Vehicle
 from .utilities import calculate_s, calculate_v_max
@@ -1238,7 +1238,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
         info = "Ego Attributes:" + get_APs(
             self, EGO_INDEX, 'in_stop_region',
             'has_entered_stop_region', 'has_stopped_in_stop_region',
-            'before_but_close_to_stop_region', 'intersection_is_clear',
+            'close_to_stop_region', 'intersection_is_clear',
             'stopped_now', 'in_intersection', 'over_speed_limit',
             'on_route', 'highest_priority', 'intersection_is_clear',
             'veh_ahead', 'lane') + \
...
@@ -273,7 +273,7 @@ class Vehicle(VehicleState):
         self.APs[
             'has_entered_stop_region'] = self.APs['has_entered_stop_region'] or self.APs['in_stop_region']
         self.APs[
-            'before_but_close_to_stop_region'] = True if rd.hlanes.near_stop_region <= self.x < rd.hlanes.stop_region[0] else False
+            'close_to_stop_region'] = True if rd.hlanes.near_stop_region <= self.x < rd.hlanes.stop_region[0] else False
         self.APs[
             'parallel_to_lane'] = True if -0.1 <= self.theta <= 0.1 else False
@@ -289,7 +289,7 @@ class Vehicle(VehicleState):
         self.APs[
             'has_entered_stop_region'] = True if rd.vlanes.stop_region[0] <= self.y else False
         self.APs[
-            'before_but_close_to_stop_region'] = True if rd.vlanes.near_stop_region <= self.y < rd.vlanes.stop_region[0] else False
+            'close_to_stop_region'] = True if rd.vlanes.near_stop_region <= self.y < rd.vlanes.stop_region[0] else False
         theta_v = -np.sign(rd.vlanes.start_pos) * np.pi / 2.0
         self.APs[
             'parallel_to_lane'] = True if theta_v - 0.1 <= self.theta <= theta_v + 0.1 else False
...
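Both renamed branches evaluate the same half-open interval test on the vehicle's lane coordinate (x for horizontal lanes, y for vertical ones). A small helper makes the geometry explicit; this refactor is hypothetical and not part of the diff:

```python
def close_to_stop_region(pos, near_stop_region, stop_region_start):
    """True when the vehicle is before the stop region but already inside
    the 'near' band preceding it (half-open interval test)."""
    return near_stop_region <= pos < stop_region_start

# Hypothetical horizontal-lane geometry: the stop region starts at x = -20
# and the 'near' band begins at x = -30.
print(close_to_stop_region(-25.0, near_stop_region=-30.0, stop_region_start=-20.0))  # True
```

Note that the diff's `True if cond else False` spelling is equivalent to the bare comparison, since a chained comparison already yields a bool.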