Commit a90b4bc5 authored by Jae Young Lee

Add and train more low-level policies, train a high-level policy.

The high-level policy was trained without the changelane maneuver but with the immediatestop maneuver. Two problems remain: 1) the agent chooses the changelane maneuver too frequently; 2) before the stop region, the immediatestop maneuver works but was not chosen properly after 2.5m of the high-level policy training...
parent 70ad9bf5
Showing changes with 215 additions and 116 deletions
@@ -51,7 +51,7 @@ class ControllerBase(PolicyBase):
         Returns state at end of node execution, total reward, epsiode_termination_flag, info
         '''
+    # TODO: this is never called when you test high-level policy rather than train...
     def step_current_node(self, visualize_low_level_steps=False):
         total_reward = 0
         self.node_terminal_state_reached = False
@@ -59,6 +59,7 @@ class ControllerBase(PolicyBase):
             observation, reward, terminal, info = self.low_level_step_current_node()
             if visualize_low_level_steps:
                 self.env.render()
+            # TODO: make the total_reward discounted....
             total_reward += reward

         total_reward += self.current_node.high_level_extra_reward
...
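The second TODO above asks for the option's return to be discounted rather than summed over low-level steps. A minimal standalone sketch of that accumulation, with `gamma` and the reward list as illustrative inputs rather than anything taken from the codebase:

# Sketch only: discounting the rewards accumulated across the low-level steps of one option,
# as suggested by the TODO above. `low_level_rewards` and `gamma` are illustrative inputs.
def discounted_option_reward(low_level_rewards, high_level_extra_reward, gamma=0.99):
    total = 0.0
    for t, r in enumerate(low_level_rewards):
        total += (gamma ** t) * r          # later low-level steps are weighted down
    return total + high_level_extra_reward

print(discounted_option_reward([1.0, 1.0, 1.0], high_level_extra_reward=0.0))  # -> 2.9701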
@@ -269,9 +269,12 @@ class DQNLearner(LearnerBase):
         """
         model = Sequential()
         model.add(Flatten(input_shape=(1, ) + self.input_shape))
-        model.add(Dense(32))
+        #model.add(Dense(64))
+        model.add(Dense(64))
         model.add(Activation('relu'))
-        model.add(Dense(32))
+        model.add(Dense(64))
+        model.add(Activation('relu'))
+        model.add(Dense(64))
         model.add(Activation('relu'))
         model.add(Dense(self.nb_actions))
         model.add(Activation('linear'))
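Read end to end, the new Q-network is three hidden layers of 64 ReLU units followed by a linear head with one output per option. A consolidated sketch of that builder, assuming standard Keras imports; the function name `build_q_network` is illustrative and not part of the diff:

# Sketch only: the model-building code implied by the hunk above, written as a free function.
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten

def build_q_network(input_shape, nb_actions):
    """Three hidden layers of 64 ReLU units, linear Q-value head (one output per option)."""
    model = Sequential()
    model.add(Flatten(input_shape=(1,) + input_shape))  # keras-rl prepends a window axis
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dense(nb_actions))
    model.add(Activation('linear'))  # raw Q-values
    return model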
@@ -435,6 +438,7 @@ class DQNAgentOverOptions(DQNAgent):
            initiation conditions."""
         invalid_node_indices = list()
         for index, option_alias in enumerate(self.low_level_policy_aliases):
+            # TODO: Locate reset_maneuver to another place as this is a "get" function.
             self.low_level_policies[option_alias].reset_maneuver()
             if not self.low_level_policies[option_alias].initiation_condition:
                 invalid_node_indices.append(index)
...
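The indices collected by this getter are presumably used to rule out options whose initiation condition fails before the agent picks its next maneuver. A generic sketch of that kind of masking, using plain numpy; it is not the actual `DQNAgentOverOptions` implementation:

# Sketch only: masking options whose initiation conditions fail before greedy selection.
# `q_values` and `invalid_node_indices` mirror the names in the hunk above; the masking
# itself is illustrative and not taken from the diff.
import numpy as np

def select_valid_option(q_values, invalid_node_indices):
    masked = np.array(q_values, dtype=float)
    masked[list(invalid_node_indices)] = -np.inf  # never pick an option that cannot initiate
    return int(np.argmax(masked))

print(select_valid_option([1.0, 5.0, 3.0], invalid_node_indices=[1]))  # -> 2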
@@ -30,6 +30,7 @@ class ManualPolicy(ControllerBase):
         new_node = None
         if self.low_level_policies[self.current_node].termination_condition:
             for next_node in self.adj[self.current_node]:
+                self.low_level_policies[next_node].reset_maneuver()
                 if self.low_level_policies[next_node].initiation_condition:
                     new_node = next_node
                     break  # change current_node to the highest priority next node
...
No preview for this file type
File added
File added
File added
@@ -3,7 +3,7 @@
         "wait": "Wait",
         "follow": "Follow",
         "stop": "Stop",
-        "changelane": "ChangeLane",
+        "immediatestop": "ImmediateStop",
         "keeplane": "KeepLane"
     },
...
@@ -174,6 +174,6 @@ class Features(object):
         # Add buffer features to make a fixed length feature vector
         for i in range(MAX_NUM_VEHICLES - len(self.other_vehs)):
-            feature += (0.0, 0.0, 0.0, 0.0, -1)
+            feature += (0.0, 0.0, 0.0, 0.0, -1.0)
         return feature
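The padding loop keeps the feature vector at a fixed length regardless of how many other vehicles are present, with `-1.0` now used consistently as the float sentinel for an empty slot. A minimal sketch of the same idea, with `MAX_NUM_VEHICLES` and the five-slot layout assumed for illustration:

# Sketch only: fixed-length padding of per-vehicle features, mirroring the hunk above.
MAX_NUM_VEHICLES = 4  # assumed value for illustration

def pad_vehicle_features(per_vehicle_tuples):
    """Flatten per-vehicle 5-tuples and pad so the output always has MAX_NUM_VEHICLES * 5 entries."""
    feature = ()
    for veh in per_vehicle_tuples:
        feature += tuple(veh)
    for _ in range(MAX_NUM_VEHICLES - len(per_vehicle_tuples)):
        feature += (0.0, 0.0, 0.0, 0.0, -1.0)  # -1.0 flags an empty slot (no vehicle present)
    return feature

assert len(pad_vehicle_features([(1.0, 2.0, 0.5, 0.0, 1.0)])) == MAX_NUM_VEHICLES * 5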
@@ -1156,12 +1156,13 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
         Returns True if the environment has terminated
         """
-        model_checks_violated = (self._LTL_preconditions_enable and \
-            self.current_model_checking_result())
+        model_checks_violated = self._LTL_preconditions_enable and \
+            self.current_model_checking_result()
         reached_goal = self._terminate_in_goal and self.goal_achieved
         self._check_collisions()
         self._check_ego_theta_out_of_range()
         terminated = self.termination_condition
         return model_checks_violated or reached_goal or terminated

     @property
@@ -1181,7 +1182,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
         return (self.ego.x >= rd.hlanes.end_pos) and \
             not self.collision_happened and \
-            not self.ego.APs['over_speed_limit']
+            (self.ego.v <= 1.1*rd.speed_limit)

     def reset(self):
         """Gym compliant reset function.
@@ -1229,7 +1230,6 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
         self.window.dispatch_events()

         # Text information about ego vehicle's states
-        # Right now, we are only training one option (Stop)
         info = "Ego Attributes:" + get_APs(
             self, EGO_INDEX, 'in_stop_region',
             'has_entered_stop_region', 'has_stopped_in_stop_region',
...
@@ -11,7 +11,7 @@ def high_level_policy_training(nb_steps=25000,
                                load_weights=False,
                                training=True,
                                testing=True,
-                               nb_episodes_for_test=10,
+                               nb_episodes_for_test=20,
                                max_nb_steps=100,
                                visualize=False,
                                tensorboard=False,
@@ -63,8 +63,7 @@ def high_level_policy_training(nb_steps=25000,
         agent.save_model(save_path)

     if testing:
-        options.set_controller_policy(agent.predict)
-        agent.test_model(options, nb_episodes=nb_episodes_for_test)
+        high_level_policy_testing(nb_episodes_for_test=nb_episodes_for_test)

     return agent
@@ -228,7 +227,6 @@ if __name__ == "__main__":
             load_weights=args.load_weights,
             save_path=args.save_file,
             tensorboard=args.tensorboard,
-            nb_episodes_for_test=20,
             visualize=args.visualize)

     if args.test:
...
 import json
+import os  # for the use of os.path.isfile
 from .simple_intersection.maneuvers import *
 from .simple_intersection.mcts_maneuvers import *
 from backends import RLController, DDPGLearner, MCTSLearner, OnlineMCTSController, ManualPolicy
@@ -155,19 +156,34 @@ class OptionsGraph:
     # TODO: error handling
     def load_trained_low_level_policies(self):
         for key, maneuver in self.maneuvers.items():
-            agent = DDPGLearner(
-                input_shape=(maneuver.get_reduced_feature_length(), ),
-                nb_actions=2,
-                gamma=0.99,
-                nb_steps_warmup_critic=200,
-                nb_steps_warmup_actor=200,
-                lr=1e-3)
-            agent.load_model("backends/trained_policies/" + key + "/" + key +
-                             "_weights.h5f")
-            maneuver.set_low_level_trained_policy(agent.predict)
-            maneuver._cost_weights = (20.0 * 1e-3, 1.0 * 1e-3, 0.25 * 1e-3,
-                                      1.0 * 1e-3, 100.0 * 1e-3, 0.1 * 1e-3,
-                                      0.25 * 1e-3, 0.1 * 1e-3)
+            trained_policy_path = "backends/trained_policies/" + key + "/"
+            critic_file_exists = os.path.isfile(trained_policy_path + key + "_weights_critic.h5f")
+            actor_file_exists = os.path.isfile(trained_policy_path + key + "_weights_actor.h5f")
+
+            if actor_file_exists and critic_file_exists:
+                agent = DDPGLearner(
+                    input_shape=(maneuver.get_reduced_feature_length(),),
+                    nb_actions=2,
+                    gamma=0.99,
+                    nb_steps_warmup_critic=200,
+                    nb_steps_warmup_actor=200,
+                    lr=1e-3)
+                agent.load_model(trained_policy_path + key + "_weights.h5f")
+                maneuver.set_low_level_trained_policy(agent.predict)
+            elif not critic_file_exists and actor_file_exists:
+                print("\n Warning: unable to load the low-level policy of \"" + key +
+                      "\". The critic weights file has to be located in the same " +
+                      "directory as the actor weights file; the manual policy will be used instead.\n")
+            else:
+                print("\n Warning: the trained low-level policy of \"" + key +
+                      "\" does not exist; the manual policy will be used.\n")
+
+            # setting the cost weights for high-level policy training.
+            # TODO: this shouldn't be initialized here, but within the ManeuverBase class
+            # (e.g., add a flag indicating high-level training and set the weights there)...
+            maneuver._cost_weights = (100.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
+                                      100.0 * 1e-3, 0.1 * 1e-3, 0.25 * 1e-3, 0.1 * 1e-3)

             if self.config["method"] == "mcts":
                 maneuver.timeout = np.inf
...
@@ -20,7 +20,7 @@ class ManeuverBase(EpisodicEnvBase):
     learning_mode = 'training'

     #: timeout (i.e., time horizon for termination)
-    # By default, the time-out horizon is 1 as in Paxton et. al (2017).
+    # By default, the time-out horizon is 1.
     timeout = 1

     #: the option specific weight vector for cost of driving, which is
@@ -153,8 +153,7 @@ class ManeuverBase(EpisodicEnvBase):
             # in this case, no additional reward by Default
             # (i.e., self._extra_r_terminal = None by default).
             self._terminal_reward_superposition(self._extra_r_terminal)
-            info[
-                'maneuver_termination_reason'] = 'extra_termination_condition'
+            info['maneuver_termination_reason'] = 'extra_termination_condition'
         if self.timeout_happened:
             if self._give_reward_on_timeout:
                 # in this case, no additional reward by Default
...
@@ -16,9 +16,11 @@ class KeepLane(ManeuverBase):
     def _init_LTL_preconditions(self):
         self._LTL_preconditions.append(LTLProperty("G ( not veh_ahead )", 0))

         #self._LTL_preconditions.append(
         #    LTLProperty("G ( not stopped_now )", 100,
         #                self._enable_low_level_training_properties))

         self._LTL_preconditions.append(
             LTLProperty(
                 "G ( (lane and target_lane) or (not lane and not target_lane) )",
@@ -63,6 +65,76 @@ class KeepLane(ManeuverBase):
         return False


+class ImmediateStop(ManeuverBase):
+
+    _terminate_in_goal = True
+    _reward_in_goal = None
+    _penalty_in_violation = None
+    _ego_pos_range = (rd.intersection_width_w_offset, rd.hlanes.end_pos)
+
+    def _init_param(self):
+        self._v_ref = 0 if self._enable_low_level_training_properties else rd.speed_limit
+        self._target_lane = self.env.ego.APs['lane']
+
+    def _init_LTL_preconditions(self):
+        self._LTL_preconditions.append(
+            LTLProperty(
+                "G ( (veh_ahead and before_but_close_to_stop_region) U highest_priority )",
+                None, not self._enable_low_level_training_properties))
+
+        self._LTL_preconditions.append(
+            LTLProperty("G ( not stopped_now )", self._penalty(self._reward_in_goal),
+                        not self._enable_low_level_training_properties))
+
+        self._LTL_preconditions.append(
+            LTLProperty(
+                "G ( (lane and target_lane) or (not lane and not target_lane) )",
+                100, self._enable_low_level_training_properties))
+
+    def generate_learning_scenario(self):
+        self.generate_scenario(
+            ego_pos_range=self._ego_pos_range,
+            ego_perturb_lim=(rd.hlanes.width / 4, np.pi / 6),
+            ego_heading_towards_lane_centre=True)
+        self.env._terminate_in_goal = False
+        self.env._reward_in_goal = None
+        self._reward_in_goal = 200
+        self._enable_low_level_training_properties = True
+
+    def generate_validation_scenario(self):
+        self._ego_pos_range = (rd.hlanes.start_pos, rd.hlanes.end_pos)
+        self.generate_learning_scenario()
+
+    def _low_level_manual_policy(self):
+        return self.env.aggressive_driving_policy(EGO_INDEX)
+
+    @staticmethod
+    def _features_dim_reduction(features_tuple):
+        return extract_ego_features(features_tuple, 'v', 'v_ref', 'e_y', 'psi',
+                                    'v tan(psi/L)', 'theta', 'lane', 'acc', 'psi_dot')
+
+    @property
+    def extra_termination_condition(self):
+        if self._enable_low_level_training_properties:  # activated only for the low-level training.
+            if self.env.ego.APs['stopped_now']:
+                if self._reward_in_goal is not None:
+                    self._extra_r_terminal = self._reward_in_goal
+                    self._extra_r_terminal *= np.exp(- pow(self.env.ego.theta, 2)
+                                                     - pow(self.env.ego.y - rd.hlanes.centres[self._target_lane], 2)
+                                                     - 0.25 * pow(self.env.ego.psi, 2))
+                else:
+                    self._extra_r_terminal = None
+                return True
+            else:
+                self._extra_r_terminal = None
+                return False
+
+        return False
+
+
 class Stop(ManeuverBase):

     _terminate_in_goal = True
@@ -79,6 +151,10 @@ class Stop(ManeuverBase):
         #    LTLProperty("G ( not has_stopped_in_stop_region )",
         #                self._penalty(self._reward_in_goal)))

+        self._LTL_preconditions.append(
+            LTLProperty("G ( not has_stopped_in_stop_region )", -150,
+                        not self._enable_low_level_training_properties))
+
         self._LTL_preconditions.append(
             LTLProperty(
                 "G ( (before_but_close_to_stop_region or in_stop_region) U has_stopped_in_stop_region )",
@@ -98,13 +174,11 @@ class Stop(ManeuverBase):
     def _set_v_ref(self):
         self._v_ref = rd.speed_limit
-        #if self._enable_low_level_training_properties:
         x = self.env.ego.x
-        if x <= rd.hlanes.near_stop_region:
-            self._v_ref = rd.speed_limit
-        elif x <= rd.hlanes.stop_region_centre:
-            self._v_ref = -(rd.speed_limit / abs(rd.hlanes.near_stop_region)
-                            ) * (x - rd.hlanes.stop_region_centre)
-        else:
+        if rd.hlanes.near_stop_region < x <= rd.hlanes.stop_region_centre:
+            self._v_ref = -(rd.speed_limit / abs(rd.hlanes.near_stop_region)) * (x - rd.hlanes.stop_region_centre)
+        elif x > rd.hlanes.stop_region_centre:
             self._v_ref = 0

     def generate_learning_scenario(self):
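The rewritten `_set_v_ref` keeps the reference speed at the speed limit before the near-stop region, ramps it linearly down to zero between that boundary and the stop-region centre, and clamps it to zero afterwards. A standalone sketch of that piecewise profile, using placeholder road constants rather than the project's values:

# Sketch only: the piecewise reference-speed profile implied by the new _set_v_ref branches.
# SPEED_LIMIT, NEAR_STOP_REGION and STOP_REGION_CENTRE are placeholders for illustration.
SPEED_LIMIT = 11.0          # m/s
NEAR_STOP_REGION = -20.0    # x-coordinate where the ramp-down starts
STOP_REGION_CENTRE = 0.0    # x-coordinate where v_ref reaches 0

def reference_speed(x):
    if x <= NEAR_STOP_REGION:
        return SPEED_LIMIT                                                        # cruise
    if x <= STOP_REGION_CENTRE:
        return -(SPEED_LIMIT / abs(NEAR_STOP_REGION)) * (x - STOP_REGION_CENTRE)  # linear ramp to 0
    return 0.0                                                                    # hold zero past the centre

for x in (-30.0, -20.0, -10.0, 0.0, 5.0):
    print(x, round(reference_speed(x), 2))   # -> 11.0, 11.0, 5.5, 0.0, 0.0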
@@ -129,18 +203,18 @@ class Stop(ManeuverBase):
     @property
     def extra_termination_condition(self):
-        if self.env.ego.APs['has_stopped_in_stop_region']:
-            if self._reward_in_goal is not None:
-                self._extra_r_terminal = self._reward_in_goal
-                self._extra_r_terminal *= np.exp(- pow(self.env.ego.theta, 2)
-                                                 - pow(self.env.ego.y - rd.hlanes.centres[self._target_lane], 2)
-                                                 - 0.25*pow(self.env.ego.psi, 2))
-            else:
-                self._extra_r_terminal = None
-            return True
-
-        elif self._enable_low_level_training_properties:  # activated only for the low-level training.
-            if (rd.speed_limit / 5 < self._v_ref) and \
-               (self.env.ego.v < self._v_ref / 2) and self.env.ego.acc < 0:
-                self._extra_r_terminal = -100
-                return True
+        if self._enable_low_level_training_properties:  # activated only for the low-level training.
+            if self.env.ego.APs['has_stopped_in_stop_region']:
+                if self._reward_in_goal is not None:
+                    self._extra_r_terminal = self._reward_in_goal
+                    self._extra_r_terminal *= np.exp(- pow(self.env.ego.theta, 2)
+                                                     - pow(self.env.ego.y - rd.hlanes.centres[self._target_lane], 2)
+                                                     - 0.25 * pow(self.env.ego.psi, 2))
+                else:
+                    self._extra_r_terminal = None
+                return True
+
+            elif (rd.speed_limit / 5 < self._v_ref) and \
+                    (self.env.ego.v < self._v_ref / 2) and self.env.ego.acc < 0:
+                self._extra_r_terminal = -100
+                return True
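Both `Stop` and the new `ImmediateStop` pay the goal reward scaled by exp(-theta^2 - (y - y_lane)^2 - 0.25*psi^2), so the full reward is only obtained when the ego stops straight and centred in its lane. A self-contained sketch of that shaping term, with made-up state values:

# Sketch only: the multiplicative terminal-reward shaping used by Stop / ImmediateStop above.
# theta: heading error, y_offset: lateral offset from the lane centre, psi: steering-related angle.
import numpy as np

def shaped_terminal_reward(reward_in_goal, theta, y_offset, psi):
    return reward_in_goal * np.exp(-theta ** 2 - y_offset ** 2 - 0.25 * psi ** 2)

print(shaped_terminal_reward(200.0, theta=0.0, y_offset=0.0, psi=0.0))            # -> 200.0 (perfect stop)
print(round(shaped_terminal_reward(200.0, theta=0.3, y_offset=0.5, psi=0.2), 1))  # -> 140.9 (sloppy stop)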
@@ -154,93 +228,91 @@ class Stop(ManeuverBase):
 class Wait(ManeuverBase):

-    _reward_in_goal = None
-    _terminate_in_goal = True
-
-    def _init_LTL_preconditions(self):
-        self._LTL_preconditions.append(
-            LTLProperty(
-                "G ( (in_stop_region and stopped_now) U (highest_priority and intersection_is_clear))", 0,
-                not self._enable_low_level_training_properties))  # not available in low-level training...
-
-        self._LTL_preconditions.append(
-            LTLProperty("G ( not (in_intersection and highest_priority) )",
-                        self._penalty(self._reward_in_goal)))
-
-        self._LTL_preconditions.append(
-            LTLProperty(
-                "G ( in_stop_region U (highest_priority and intersection_is_clear) )", 150, self._enable_low_level_training_properties))
-
-        self._LTL_preconditions.append(
-            LTLProperty(
-                "G ( (lane and target_lane) or (not lane and not target_lane) )",
-                100, self._enable_low_level_training_properties))
-
-    def _init_param(self):
-        self._update_param()
-        self._target_lane = self.env.ego.APs['lane']
-
-    def _update_param(self):
-        if self.env.ego.APs['highest_priority'] and self.env.ego.APs['intersection_is_clear']:
-            self._v_ref = rd.speed_limit / 5
-        else:
-            self._v_ref = 0
-
-    def generate_learning_scenario(self):
-        n_others = 0 if np.random.rand() <= 0 else np.random.randint(1, 4)
-        self.generate_scenario(
-            enable_LTL_preconditions=True,
-            n_others_range=(n_others, n_others),
-            ego_pos_range=rd.hlanes.stop_region,
-            n_others_stopped_in_stop_region=n_others,
-            ego_v_upper_lim=0,
-            ego_perturb_lim=(rd.hlanes.width / 4, np.pi / 6),
-            ego_heading_towards_lane_centre=True)
-
-        max_waited_count = 0
-        min_waited_count = 1
-        for veh in self.env.vehs[1:]:
-            max_waited_count = max(max_waited_count, veh.waited_count)
-            min_waited_count = min(min_waited_count, veh.waited_count)
-        min_waited_count = min(min_waited_count, max_waited_count)
-
-        self._extra_action_weights_flag = False
-        if np.random.rand() <= 0.2:
-            self.env.ego.waited_count = np.random.randint(0, min_waited_count+1)
-        else:
-            self.env.ego.waited_count = np.random.randint(min_waited_count, max_waited_count + 21)
-        self.env.init_APs(False)
-        self._reward_in_goal = 200
-        self._enable_low_level_training_properties = True
-
-    def generate_validation_scenario(self):
-        super().generate_validation_scenario()
-        #self._enable_low_level_training_properties = True
-
-    @property
-    def extra_termination_condition(self):
-        if self._enable_low_level_training_properties:  # activated only for the low-level training.
-            if self.env.ego.APs['highest_priority'] and self.env.ego.APs['intersection_is_clear'] \
-                    and np.random.rand() <= 0.05 and self.env.ego.v <= self._v_ref / 10:
-                self._extra_r_terminal = - 100
-                return True
-            else:
-                self._extra_r_terminal = None
-                return False
-        return False
-
-    @staticmethod
-    def _features_dim_reduction(features_tuple):
-        return extract_ego_features(
-            features_tuple, 'v', 'v_ref', 'e_y', 'psi', 'v tan(psi/L)', 'theta', 'lane', 'acc', 'psi_dot',
-            'pos_stop_region', 'intersection_is_clear', 'highest_priority')
+    _terminate_in_goal = True
+    _reward_in_goal = None
+
+    def _init_LTL_preconditions(self):
+        self._LTL_preconditions.append(
+            LTLProperty(
+                "G ( (in_stop_region and stopped_now) and not (highest_priority and intersection_is_clear))", 0,
+                not self._enable_low_level_training_properties))  # not available in low-level training...
+
+        #LTLProperty(
+        #    "G ( (in_stop_region and stopped_now) U (highest_priority and intersection_is_clear))", 0,
+        #    not self._enable_low_level_training_properties))  # not available in low-level training...
+
+        #self._LTL_preconditions.append(
+        #    LTLProperty("G ( not (in_intersection and highest_priority) )",
+        #                self._penalty(self._reward_in_goal)))
+
+        #self._LTL_preconditions.append(
+        #    LTLProperty(
+        #        "G ( in_stop_region U (highest_priority and intersection_is_clear) )", 150, self._enable_low_level_training_properties))
+
+        #self._LTL_preconditions.append(
+        #    LTLProperty(
+        #        "G ( (lane and target_lane) or (not lane and not target_lane) )",
+        #        100, self._enable_low_level_training_properties))
+
+    def _init_param(self):
+        self._v_ref = 0  #if self._enable_low_level_training_properties else rd.speed_limit
+        self._target_lane = self.env.ego.APs['lane']
+
+    def _low_level_manual_policy(self):
+        return (0, 0)  # Do nothing during "Wait" but just wait until the highest priority is given.
+
+    # @staticmethod
+    # def _features_dim_reduction(features_tuple):
+    #     return extract_ego_features(
+    #         features_tuple, 'v', 'v_ref', 'e_y', 'psi', 'v tan(psi/L)', 'theta', 'lane', 'acc', 'psi_dot',
+    #         'pos_stop_region', 'intersection_is_clear', 'highest_priority')
+
+
+class Left(ManeuverBase):
+
+    min_y_distance = rd.hlanes.width / 4
+
+    _terminate_in_goal = True
+    _reward_in_goal = None
+
+    def _init_param(self):
+        self._v_ref = rd.speed_limit
+        self._target_lane = False
+        self._terminate_in_goal = True
+
+    @property
+    def goal_achieved(self):
+        ego = self.env.ego
+        APs = self.env.ego.APs
+        on_other_lane = APs['lane'] == self._target_lane
+        achieved_y_displacement = np.sign(ego.y) * \
+            (ego.y - rd.hlanes.centres[APs['target_lane']]) >= - self.min_y_distance
+        return on_other_lane and APs['on_route'] and \
+            achieved_y_displacement and APs['parallel_to_lane']
+
+    @property
+    def extra_initiation_condition(self):
+        return self.env.ego.APs['lane']
+
+    @staticmethod
+    def _features_dim_reduction(features_tuple):
+        return extract_ego_features(features_tuple, 'v', 'v_ref', 'e_y', 'psi',
+                                    'v tan(psi/L)', 'theta', 'lane', 'acc',
+                                    'psi_dot')
+
+
+class Right(Left):
+
+    def _init_param(self):
+        self._v_ref = rd.speed_limit
+        self._target_lane = True
+        self._terminate_in_goal = True
+
+    @property
+    def extra_initiation_condition(self):
+        return not self.env.ego.APs['lane']


 class ChangeLane(ManeuverBase):
@@ -252,12 +324,11 @@ class ChangeLane(ManeuverBase):
     _violation_penalty_in_low_level_training = None

-    high_level_extra_reward = -20
+    # high_level_extra_reward = -1000000

     def _init_param(self):
         self._v_ref = rd.speed_limit
         self._target_lane = not self.env.ego.APs['lane']
-        self._terminate_in_goal = True

     def _init_LTL_preconditions(self):
         self._LTL_preconditions.append(
@@ -307,8 +378,6 @@ class ChangeLane(ManeuverBase):
         self._violation_penalty_in_low_level_training = 150
         self._enable_low_level_training_properties = True

-    # TODO: It is not a good idea to specify features by numbers, as the list
-    # of features is ever changing. We should specify them by strings.
     @staticmethod
     def _features_dim_reduction(features_tuple):
         return extract_ego_features(features_tuple, 'v', 'v_ref', 'e_y', 'psi',
@@ -357,11 +426,22 @@ class Follow(ManeuverBase):
     def generate_validation_scenario(self):
         self.generate_learning_scenario()

+    def _init_param(self):
+        self._set_v_ref()
+
     def _update_param(self):
-        self._target_veh_i, _ = self.env.get_V2V_distance()
-        if self._target_veh_i is not None:
-            self._v_ref = self.env.vehs[self._target_veh_i].v
+        self._set_v_ref()
+
+    def _set_v_ref(self):
+        #if self._enable_low_level_training_properties:
+        self._target_veh_i, _ = self.env.get_V2V_distance()
+        if self._target_veh_i is not None:
+            self._v_ref = self.env.vehs[self._target_veh_i].v
+        else:
+            self._v_ref = 0
+        #else:
+        #    self._v_ref = rd.speed_limit

     def _low_level_manual_policy(self):
         return self.env.aggressive_driving_policy(EGO_INDEX)
...