From 5c3b366b92aa1d8c2547ec39fc3bb7028e02d756 Mon Sep 17 00:00:00 2001
From: Jaeyoung Lee <jaeyoung.lee@uwaterloo.ca>
Date: Sun, 18 Nov 2018 03:01:59 -0500
Subject: [PATCH] Improve the Wait and KeepLane maneuvers

* Add an n_others_with_higher_priority property to SimpleIntersectionEnv
  and use it in Wait._update_param.
* Turn off the low-level training properties in the default validation
  scenario of ManeuverBase.
* Start the ego between the near-stop region and the end of the horizontal
  lane in the KeepLane learning scenario.
* In the Wait learning scenario, penalize timeouts with -200 and terminate
  with -200 when the ego keeps the highest priority for 50 update steps.
* Fix the formatting of a Wait LTL precondition.
---
 .../simple_intersection_env.py                | 10 ++++++++
 options/simple_intersection/maneuver_base.py  |  4 +--
 options/simple_intersection/maneuvers.py      | 28 +++++++++++++-----
 3 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/env/simple_intersection/simple_intersection_env.py b/env/simple_intersection/simple_intersection_env.py
index 2cf5edf..c74a161 100644
--- a/env/simple_intersection/simple_intersection_env.py
+++ b/env/simple_intersection/simple_intersection_env.py
@@ -1213,3 +1213,13 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
         self.window = backup_window
 
         return new_env
+
+    @property
+    def n_others_with_higher_priority(self):
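+        """The number of other vehicles that have waited longer than the ego vehicle."""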
+        n_others_with_higher_priority = 0
+        for veh in self.vehs[1:]:
+            if veh.waited_count > self.ego.waited_count:
+                n_others_with_higher_priority += 1
+
+        return n_others_with_higher_priority
diff --git a/options/simple_intersection/maneuver_base.py b/options/simple_intersection/maneuver_base.py
index ebed07c..4b47f95 100644
--- a/options/simple_intersection/maneuver_base.py
+++ b/options/simple_intersection/maneuver_base.py
@@ -273,9 +273,9 @@ class ManeuverBase(EpisodicEnvBase):
     def generate_learning_scenario(self):
         raise NotImplementedError(self.__class__.__name__ + ".generate_learning_scenario is not implemented.")
 
-    def generate_validation_scenario(self):
-        # If not implemented, use learning scenario.
+    def generate_validation_scenario(self):  # Override this method in a subclass if customization is needed.
         self.generate_learning_scenario()
+        self._enable_low_level_training_properties = False
 
     def generate_scenario(self, enable_LTL_preconditions=True, timeout=np.infty, **kwargs):
         """generates the scenario for low-level policy learning and validation. This method
diff --git a/options/simple_intersection/maneuvers.py b/options/simple_intersection/maneuvers.py
index bb6a764..8e2f589 100644
--- a/options/simple_intersection/maneuvers.py
+++ b/options/simple_intersection/maneuvers.py
@@ -24,7 +24,7 @@ class KeepLane(ManeuverBase):
 
     def generate_learning_scenario(self):
         self.generate_scenario(enable_LTL_preconditions=False,
-                               ego_pos_range=(rd.hlanes.start_pos, rd.hlanes.end_pos),
+                               ego_pos_range=(rd.hlanes.near_stop_region, rd.hlanes.end_pos),
                                ego_perturb_lim=(rd.hlanes.width / 4, np.pi / 6),
                                ego_heading_towards_lane_centre=True)
         # the goal reward and termination are handled by the SimpleIntersectionEnv
@@ -103,7 +103,7 @@ class Wait(ManeuverBase):
 
     def _init_LTL_preconditions(self):
         self._LTL_preconditions.append(
-            LTLProperty("G ( (in_stop_region and stopped_now) U highest_priority)", 0))
+            LTLProperty("G ( (in_stop_region and stopped_now) U highest_priority )", 0))
 
         self._LTL_preconditions.append(
             LTLProperty("G ( not (in_intersection and highest_priority) )",
@@ -113,15 +113,15 @@
         ego = self.env.ego
         self._v_ref = rd.speed_limit if self.env.ego.APs['highest_priority'] else 0
         self._target_lane = ego.APs['lane']
-
-        self._n_others_with_higher_priority = 0
-        for veh in self.env.vehs[1:]:
-            if veh.waited_count > ego.waited_count:
-                self._n_others_with_higher_priority += 1
+        self._ego_stop_count = 0
 
     def _update_param(self):
         if self.env.ego.APs['highest_priority']:
             self._v_ref = rd.speed_limit
+            if self._enable_low_level_training_properties:
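+                # count the steps for which the ego keeps the highest priority (see extra_termination_condition)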
+                if self.env.n_others_with_higher_priority == 0:
+                    self._ego_stop_count += 1
 
     def generate_learning_scenario(self):
         n_others = np.random.randint(0, 3)
@@ -141,6 +141,20 @@
         self.env.init_APs(False)
 
         self._reward_in_goal = 200
+        self._extra_r_on_timeout = -200
+        self._enable_low_level_training_properties = True
+        self._ego_stop_count = 0
+
+    @property
+    def extra_termination_condition(self):
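+        """Terminate the episode with a penalty if the ego keeps the highest priority for 50 update steps."""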
+        if self._enable_low_level_training_properties:  # activated only during low-level training.
+            if self._ego_stop_count >= 50:
+                self._extra_r_terminal = -200
+                return True
+            self._extra_r_terminal = None
+
+        return False
 
     @staticmethod
     def _features_dim_reduction(features_tuple):
-- 
GitLab