Commit 5c3b366b authored by Jae Young Lee

Improve Wait and KeepLane, minor changes...

parent d1dbbd44
1 merge request: !2 Final test
@@ -1213,3 +1213,12 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
         self.window = backup_window
         return new_env
 
+    @property
+    def n_others_with_higher_priority(self):
+        n_others_with_higher_priority = 0
+        for veh in self.vehs[1:]:
+            if veh.waited_count > self.ego.waited_count:
+                n_others_with_higher_priority += 1
+        return n_others_with_higher_priority
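
Note: this new read-only property centralizes the "how many other vehicles outrank the ego" count that Wait previously computed inline in its init code (removed in the -113,15 hunk below). A minimal standalone sketch of the pattern, using hypothetical Vehicle/Env stand-ins rather than the repo's classes:

class Vehicle:
    def __init__(self, waited_count):
        self.waited_count = waited_count

class Env:
    def __init__(self, vehs):
        self.vehs = vehs      # vehs[0] is assumed to be the ego vehicle
        self.ego = vehs[0]

    @property
    def n_others_with_higher_priority(self):
        # A vehicle outranks the ego if it has been waiting longer.
        return sum(1 for veh in self.vehs[1:]
                   if veh.waited_count > self.ego.waited_count)

env = Env([Vehicle(3), Vehicle(5), Vehicle(1)])
assert env.n_others_with_higher_priority == 1   # only the waited_count=5 vehicle

Because it is a property, callers such as Wait._update_param can re-read it on every simulation step without caching a stale count.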
@@ -273,9 +273,9 @@ class ManeuverBase(EpisodicEnvBase):
     def generate_learning_scenario(self):
         raise NotImplemented(self.__class__.__name__ + ".generate_learning_scenario is not implemented.")
 
-    def generate_validation_scenario(self):
-        # If not implemented, use learning scenario.
+    def generate_validation_scenario(self):  # Override this method in the subclass if some customization is needed.
         self.generate_learning_scenario()
+        self._enable_low_level_training_properties = False
 
     def generate_scenario(self, enable_LTL_preconditions=True, timeout=np.infty, **kwargs):
         """generates the scenario for low-level policy learning and validation. This method
@@ -24,7 +24,7 @@ class KeepLane(ManeuverBase):
     def generate_learning_scenario(self):
         self.generate_scenario(enable_LTL_preconditions=False,
-                               ego_pos_range=(rd.hlanes.start_pos, rd.hlanes.end_pos),
+                               ego_pos_range=(rd.hlanes.near_stop_region, rd.hlanes.end_pos),
                                ego_perturb_lim=(rd.hlanes.width / 4, np.pi / 6),
                                ego_heading_towards_lane_centre=True)
         # the goal reward and termination is led by the SimpleIntersectionEnv
@@ -103,7 +103,7 @@ class Wait(ManeuverBase):
     def _init_LTL_preconditions(self):
         self._LTL_preconditions.append(
-            LTLProperty("G ( (in_stop_region and stopped_now) U highest_priority)", 0))
+            LTLProperty("G ( (in_stop_region and stopped_now) U highest_priority )", 0))
 
         self._LTL_preconditions.append(
             LTLProperty("G ( not (in_intersection and highest_priority) )",
@@ -113,15 +113,14 @@ class Wait(ManeuverBase):
         ego = self.env.ego
         self._v_ref = rd.speed_limit if self.env.ego.APs['highest_priority'] else 0
         self._target_lane = ego.APs['lane']
-
-        self._n_others_with_higher_priority = 0
-        for veh in self.env.vehs[1:]:
-            if veh.waited_count > ego.waited_count:
-                self._n_others_with_higher_priority += 1
+        self._ego_stop_count = 0
 
     def _update_param(self):
         if self.env.ego.APs['highest_priority']:
             self._v_ref = rd.speed_limit
+        if self._enable_low_level_training_properties:
+            if self.env.n_others_with_higher_priority == 0:
+                self._ego_stop_count += 1
 
     def generate_learning_scenario(self):
         n_others = np.random.randint(0, 3)
@@ -141,6 +140,19 @@ class Wait(ManeuverBase):
         self.env.init_APs(False)
 
         self._reward_in_goal = 200
+        self._extra_r_on_timeout = -200
+        self._enable_low_level_training_properties = True
+        self._ego_stop_count = 0
+
+    @property
+    def extra_termination_condition(self):
+        if self._enable_low_level_training_properties:  # activated only for the low-level training.
+            if self._ego_stop_count >= 50:
+                self._extra_r_terminal = -200
+                return True
+            else:
+                self._extra_r_terminal = None
+
+        return False
 
     @staticmethod
     def _features_dim_reduction(features_tuple):
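
Note: taken together, the _update_param and extra_termination_condition changes add a stall timeout for low-level training: _ego_stop_count grows on every step in which the ego already outranks all other vehicles, and the maneuver terminates with an extra -200 once the count reaches 50. A standalone sketch of that interaction, with hypothetical stand-ins rather than the repo's classes:

class WaitSketch:
    def __init__(self):
        self._enable_low_level_training_properties = True
        self._ego_stop_count = 0
        self._extra_r_terminal = None

    def update(self, n_others_with_higher_priority):
        # Mirrors the new _update_param logic: count the steps in which the
        # ego has the right of way but the episode is still running.
        if (self._enable_low_level_training_properties
                and n_others_with_higher_priority == 0):
            self._ego_stop_count += 1

    @property
    def extra_termination_condition(self):
        if self._enable_low_level_training_properties:
            if self._ego_stop_count >= 50:
                self._extra_r_terminal = -200  # penalize stalling at the stop line
                return True
            else:
                self._extra_r_terminal = None
        return False

m = WaitSketch()
for _ in range(50):   # ego keeps waiting even though no one outranks it
    m.update(n_others_with_higher_priority=0)
assert m.extra_termination_condition and m._extra_r_terminal == -200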