diff --git a/env/simple_intersection/simple_intersection_env.py b/env/simple_intersection/simple_intersection_env.py
index 2cf5edf09893ad20832a572a1935f8680ca6f016..c74a1612d1625609a4f850bd44724e888ee99286 100644
--- a/env/simple_intersection/simple_intersection_env.py
+++ b/env/simple_intersection/simple_intersection_env.py
@@ -1213,3 +1213,12 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
         self.window = backup_window
 
         return new_env
+
+    @property
+    def n_others_with_higher_priority(self):
+        n_others_with_higher_priority = 0
+        for veh in self.vehs[1:]:
+            if veh.waited_count > self.ego.waited_count:
+                n_others_with_higher_priority += 1
+
+        return n_others_with_higher_priority
diff --git a/options/simple_intersection/maneuver_base.py b/options/simple_intersection/maneuver_base.py
index ebed07ca94f090af30d33438fe2d384dd3df244b..4b47f956777f8475504aa8b99806be4653ee24a7 100644
--- a/options/simple_intersection/maneuver_base.py
+++ b/options/simple_intersection/maneuver_base.py
@@ -273,9 +273,9 @@ class ManeuverBase(EpisodicEnvBase):
     def generate_learning_scenario(self):
         raise NotImplemented(self.__class__.__name__ + ".generate_learning_scenario is not implemented.")
 
-    def generate_validation_scenario(self):
-        # If not implemented, use learning scenario.
+    def generate_validation_scenario(self):  # Override this method in the subclass if some customization is needed.
         self.generate_learning_scenario()
+        self._enable_low_level_training_properties = False
 
     def generate_scenario(self, enable_LTL_preconditions=True, timeout=np.infty, **kwargs):
         """generates the scenario for low-level policy learning and validation. This method
diff --git a/options/simple_intersection/maneuvers.py b/options/simple_intersection/maneuvers.py
index bb6a7649a80313b0900291da7f23a1dc6f35fb39..8e2f58971316e0f45d651999b07a8567ec8dbd7e 100644
--- a/options/simple_intersection/maneuvers.py
+++ b/options/simple_intersection/maneuvers.py
@@ -24,7 +24,7 @@ class KeepLane(ManeuverBase):
 
     def generate_learning_scenario(self):
         self.generate_scenario(enable_LTL_preconditions=False,
-                               ego_pos_range=(rd.hlanes.start_pos, rd.hlanes.end_pos),
+                               ego_pos_range=(rd.hlanes.near_stop_region, rd.hlanes.end_pos),
                                ego_perturb_lim=(rd.hlanes.width / 4, np.pi / 6),
                                ego_heading_towards_lane_centre=True)
         # the goal reward and termination is led by the SimpleIntersectionEnv
@@ -103,7 +103,7 @@ class Wait(ManeuverBase):
 
     def _init_LTL_preconditions(self):
         self._LTL_preconditions.append(
-            LTLProperty("G ( (in_stop_region and stopped_now) U highest_priority)", 0))
+            LTLProperty("G ( (in_stop_region and stopped_now) U highest_priority )", 0))
 
         self._LTL_preconditions.append(
             LTLProperty("G ( not (in_intersection and highest_priority) )",
@@ -113,15 +113,14 @@ class Wait(ManeuverBase):
         ego = self.env.ego
         self._v_ref = rd.speed_limit if self.env.ego.APs['highest_priority'] else 0
         self._target_lane = ego.APs['lane']
-
-        self._n_others_with_higher_priority = 0
-        for veh in self.env.vehs[1:]:
-            if veh.waited_count > ego.waited_count:
-                self._n_others_with_higher_priority += 1
+        self._ego_stop_count = 0
 
     def _update_param(self):
         if self.env.ego.APs['highest_priority']:
             self._v_ref = rd.speed_limit
+        if self._enable_low_level_training_properties:
+            if self.env.n_others_with_higher_priority == 0:
+                self._ego_stop_count += 1
 
     def generate_learning_scenario(self):
         n_others = np.random.randint(0, 3)
@@ -141,6 +140,19 @@ class Wait(ManeuverBase):
         self.env.init_APs(False)
 
         self._reward_in_goal = 200
+        self._extra_r_on_timeout = -200
+        self._enable_low_level_training_properties = True
+        self._ego_stop_count = 0
+
+    @property
+    def extra_termination_condition(self):
+        if self._enable_low_level_training_properties:  # activated only for the low-level training.
+            if self._ego_stop_count >= 50:
+                self._extra_r_terminal = -200
+                return True
+            else:
+                self._extra_r_terminal = None
+                return False
 
     @staticmethod
     def _features_dim_reduction(features_tuple):