From 5c3b366b92aa1d8c2547ec39fc3bb7028e02d756 Mon Sep 17 00:00:00 2001
From: Jaeyoung Lee <jaeyoung.lee@uwaterloo.ca>
Date: Sun, 18 Nov 2018 03:01:59 -0500
Subject: [PATCH] Improve Wait and KeepLane, minor changes...

---
NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line
breaks follow the hunk headers and diffstat; indentation is inferred from
Python syntax - confirm against the repository before applying.

 .../simple_intersection_env.py               |  9 +++++++
 options/simple_intersection/maneuver_base.py |  4 +--
 options/simple_intersection/maneuvers.py     | 26 ++++++++++++++-----
 3 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/env/simple_intersection/simple_intersection_env.py b/env/simple_intersection/simple_intersection_env.py
index 2cf5edf..c74a161 100644
--- a/env/simple_intersection/simple_intersection_env.py
+++ b/env/simple_intersection/simple_intersection_env.py
@@ -1213,3 +1213,12 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
         self.window = backup_window
 
         return new_env
+
+    @property
+    def n_others_with_higher_priority(self):
+        n_others_with_higher_priority = 0
+        for veh in self.vehs[1:]:
+            if veh.waited_count > self.ego.waited_count:
+                n_others_with_higher_priority += 1
+
+        return n_others_with_higher_priority
diff --git a/options/simple_intersection/maneuver_base.py b/options/simple_intersection/maneuver_base.py
index ebed07c..4b47f95 100644
--- a/options/simple_intersection/maneuver_base.py
+++ b/options/simple_intersection/maneuver_base.py
@@ -273,9 +273,9 @@ class ManeuverBase(EpisodicEnvBase):
     def generate_learning_scenario(self):
         raise NotImplemented(self.__class__.__name__ + ".generate_learning_scenario is not implemented.")
 
-    def generate_validation_scenario(self):
-        # If not implemented, use learning scenario.
+    def generate_validation_scenario(self):  # Override this method in the subclass if some customization is needed.
         self.generate_learning_scenario()
+        self._enable_low_level_training_properties = False
 
     def generate_scenario(self, enable_LTL_preconditions=True, timeout=np.infty, **kwargs):
         """generates the scenario for low-level policy learning and validation. This method
diff --git a/options/simple_intersection/maneuvers.py b/options/simple_intersection/maneuvers.py
index bb6a764..8e2f589 100644
--- a/options/simple_intersection/maneuvers.py
+++ b/options/simple_intersection/maneuvers.py
@@ -24,7 +24,7 @@ class KeepLane(ManeuverBase):
 
     def generate_learning_scenario(self):
         self.generate_scenario(enable_LTL_preconditions=False,
-                               ego_pos_range=(rd.hlanes.start_pos, rd.hlanes.end_pos),
+                               ego_pos_range=(rd.hlanes.near_stop_region, rd.hlanes.end_pos),
                                ego_perturb_lim=(rd.hlanes.width / 4, np.pi / 6),
                                ego_heading_towards_lane_centre=True)
         # the goal reward and termination is led by the SimpleIntersectionEnv
@@ -103,7 +103,7 @@ class Wait(ManeuverBase):
 
     def _init_LTL_preconditions(self):
         self._LTL_preconditions.append(
-            LTLProperty("G ( (in_stop_region and stopped_now) U highest_priority)", 0))
+            LTLProperty("G ( (in_stop_region and stopped_now) U highest_priority )", 0))
 
         self._LTL_preconditions.append(
             LTLProperty("G ( not (in_intersection and highest_priority) )",
@@ -113,15 +113,14 @@ class Wait(ManeuverBase):
         ego = self.env.ego
         self._v_ref = rd.speed_limit if self.env.ego.APs['highest_priority'] else 0
         self._target_lane = ego.APs['lane']
-
-        self._n_others_with_higher_priority = 0
-        for veh in self.env.vehs[1:]:
-            if veh.waited_count > ego.waited_count:
-                self._n_others_with_higher_priority += 1
+        self._ego_stop_count = 0
 
     def _update_param(self):
         if self.env.ego.APs['highest_priority']:
             self._v_ref = rd.speed_limit
+        if self._enable_low_level_training_properties:
+            if self.env.n_others_with_higher_priority == 0:
+                self._ego_stop_count += 1
 
     def generate_learning_scenario(self):
         n_others = np.random.randint(0, 3)
@@ -141,6 +140,19 @@ class Wait(ManeuverBase):
         self.env.init_APs(False)
 
         self._reward_in_goal = 200
+        self._extra_r_on_timeout = -200
+        self._enable_low_level_training_properties = True
+        self._ego_stop_count = 0
+
+    @property
+    def extra_termination_condition(self):
+        if self._enable_low_level_training_properties:  # activated only for the low-level training.
+            if self._ego_stop_count >= 50:
+                self._extra_r_terminal = -200
+                return True
+            else:
+                self._extra_r_terminal = None
+                return False
 
     @staticmethod
     def _features_dim_reduction(features_tuple):
-- 
GitLab