Commit 77d61b11 authored by Jae Young Lee

Merge branch 'master' into retraining_wait_maneuver

Parents: ee257c39 062ad4ff
File added
File added
@@ -53,9 +53,8 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
     # can be set to a specific value which may be different than the default.
     # TODO: check _cost_weights in both here and ManeuverBase. The _cost_weights has to be substituted to here, but it doesn't sometimes.
     # TODO: set a functionality of setting _cost_weights for low and high level training separately.
-    _cost_weights = (0, 0, 0, 0, 0, 0, 0, 0)
-    #(10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
-    #100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)
+    _cost_weights = (10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
+                     100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)

     #TODO: Move this to constants
     # The empirical min and max of each term in the cost vector, which is used to normalize the values
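The merge replaces the zeroed-out cost weights with the previously commented-out values, so each of the eight cost terms is scaled before it is subtracted from the reward. A rough sketch of how such a weight tuple is typically folded into a single penalty (the helper below and the normalization step are illustrative assumptions, not this repository's API):

import numpy as np

def weighted_cost(cost_terms, cost_weights, bounds=None):
    # Hypothetical helper: collapse a per-term cost vector into one scalar penalty.
    # bounds, if given, is a (min, max) pair of arrays mirroring the
    # "empirical min and max" normalization mentioned in the comment above.
    c = np.asarray(cost_terms, dtype=float)
    if bounds is not None:
        lo, hi = (np.asarray(b, dtype=float) for b in bounds)
        c = (c - lo) / np.maximum(hi - lo, 1e-12)  # scale each term to [0, 1]
    return float(np.dot(np.asarray(cost_weights, dtype=float), c))

# e.g. reward -= weighted_cost(cost_terms, SimpleIntersectionEnv._cost_weights)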
@@ -47,7 +47,8 @@ def high_level_policy_training(nb_steps=25000,
         nb_actions=options.get_number_of_nodes(),
         target_model_update=1e-3,
         delta_clip=100,
-        low_level_policies=options.maneuvers)
+        low_level_policies=options.maneuvers,
+        gamma=1)

     if load_weights:
         agent.load_model(save_path)

@@ -77,7 +78,8 @@ def high_level_policy_testing(nb_episodes_for_test=100,
     agent = DQNLearner(
         input_shape=(50, ),
         nb_actions=options.get_number_of_nodes(),
-        low_level_policies=options.maneuvers)
+        low_level_policies=options.maneuvers,
+        gamma=1)

     if pretrained:
         trained_agent_file = "backends/trained_policies/highlevel/" + trained_agent_file

@@ -99,7 +101,8 @@ def evaluate_high_level_policy(nb_episodes_for_test=100,
     agent = DQNLearner(
         input_shape=(50, ),
         nb_actions=options.get_number_of_nodes(),
-        low_level_policies=options.maneuvers)
+        low_level_policies=options.maneuvers,
+        gamma=1)

     if pretrained:
         trained_agent_file = "backends/trained_policies/highlevel/" + trained_agent_file
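All three DQNLearner instantiations now pass gamma=1, i.e. an undiscounted return at the high level: every future decision step counts as much as the current one, which suits short episodic tasks where the episode length itself bounds the return. A minimal sketch of what the discount factor does to the return (standard RL bookkeeping; whether DQNLearner simply forwards gamma to its underlying agent is not shown in this hunk):

def discounted_return(rewards, gamma=1.0):
    # G = sum_t gamma**t * r_t, accumulated from the last step backwards.
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g

# Three decision steps with rewards [1, 1, 10]:
#   gamma=0.9 -> 1 + 0.9*1 + 0.81*10 = 10.0
#   gamma=1.0 -> 12 (plain sum, no bias against the late +10)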
File added
@@ -30,9 +30,8 @@ class ManeuverBase(EpisodicEnvBase):
     # as a negative reward, so a cost will be summed up to the reward
     # with subtraction.
     # TODO: remove or to provide additional functionality, keep _cost_weights in ManeuverBase here (see other TODOs in simple_intersection_env regarding _cost_weights).
-    _cost_weights = (0, 0, 0, 0, 0, 0, 0, 0)
-    #(10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
-    #100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)
+    _cost_weights = (10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
+                     100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)

     _extra_r_terminal = None
     _extra_r_on_timeout = None
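The same weight tuple is duplicated in ManeuverBase, and the TODOs in both files note that the two copies sometimes fall out of sync. A small guard like the following could catch that drift in a test; the helper name and the idea of calling it from a test are suggestions, not part of this change:

def assert_cost_weights_in_sync(env_weights, maneuver_weights):
    # Fail loudly if the env's _cost_weights and ManeuverBase._cost_weights diverge.
    if len(env_weights) != len(maneuver_weights):
        raise AssertionError("cost weight tuples have different lengths")
    for i, (a, b) in enumerate(zip(env_weights, maneuver_weights)):
        if a != b:
            raise AssertionError("cost weight %d differs: %r vs %r" % (i, a, b))

# e.g. assert_cost_weights_in_sync(SimpleIntersectionEnv._cost_weights,
#                                  ManeuverBase._cost_weights)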
@@ -243,34 +243,81 @@ class Wait(ManeuverBase):

     def _init_LTL_preconditions(self):
         self._LTL_preconditions.append(
-            LTLProperty("G ( (in_stop_region and stopped_now) and not (highest_priority and intersection_is_clear))",
+            LTLProperty("G ( (in_stop_region and stopped_now) U (highest_priority and intersection_is_clear))",
                         None, not self._enable_low_level_training_properties))  # not available in low-level training...

         self._LTL_preconditions.append(
-            LTLProperty("G ( not (in_intersection and highest_priority) )",
+            LTLProperty("G ( not (in_intersection and highest_priority and intersection_is_clear) )",
                         self._penalty(self._reward_in_goal)))

         self._LTL_preconditions.append(
             LTLProperty(
-                "G ( in_stop_region U (highest_priority and intersection_is_clear) )", 150, self._enable_low_level_training_properties))
+                "G ( in_stop_region U (highest_priority and intersection_is_clear) )", 150,
+                self._enable_low_level_training_properties))

         self._LTL_preconditions.append(
             LTLProperty(
                 "G ( (lane and target_lane) or (not lane and not target_lane) )",
-                100, self._enable_low_level_training_properties))
+                150, self._enable_low_level_training_properties))

     def _init_param(self):
-        self._v_ref = 0 #if self._enable_low_level_training_properties else rd.speed_limit
+        self._update_param()
         self._target_lane = self.env.ego.APs['lane']

-    def _low_level_manual_policy(self):
-        return (0, 0)  # Do nothing during "Wait" but just wait until the highest priority is given.
-
-    # @staticmethod
-    # def _features_dim_reduction(features_tuple):
-    #     return extract_ego_features(
-    #         features_tuple, 'v', 'v_ref', 'e_y', 'psi', 'v tan(psi/L)', 'theta', 'lane', 'acc', 'psi_dot',
-    #         'pos_stop_region', 'intersection_is_clear', 'highest_priority')
+    def _update_param(self):
+        if self.env.ego.APs['highest_priority'] and self.env.ego.APs['intersection_is_clear']:
+            self._v_ref = rd.speed_limit
+        else:
+            self._v_ref = 0
+
+    def generate_learning_scenario(self):
+        n_others = 0 if np.random.rand() <= 0 else np.random.randint(1, 4)
+        self.generate_scenario(
+            n_others_range=(n_others, n_others),
+            ego_pos_range=rd.hlanes.stop_region,
+            n_others_stopped_in_stop_region=n_others,
+            ego_v_upper_lim=0,
+            ego_perturb_lim=(rd.hlanes.width / 4, np.pi / 6),
+            ego_heading_towards_lane_centre=True)
+
+        max_waited_count = 0
+        min_waited_count = 1
+        for veh in self.env.vehs[1:]:
+            max_waited_count = max(max_waited_count, veh.waited_count)
+            min_waited_count = min(min_waited_count, veh.waited_count)
+
+        min_waited_count = min(min_waited_count, max_waited_count)
+
+        self._extra_action_weights_flag = False
+        if np.random.rand() <= 0.5:
+            self.env.ego.waited_count = np.random.randint(0, min_waited_count+1)
+        else:
+            self.env.ego.waited_count = np.random.randint(min_waited_count, max_waited_count + 21)
+
+        self.env.init_APs(False)
+
+        self._reward_in_goal = 200
+        self._enable_low_level_training_properties = True
+        self._extra_action_weights_flag = True
+
+    @property
+    def extra_termination_condition(self):
+        if self._enable_low_level_training_properties:  # activated only for the low-level training.
+            if self.env.ego.APs['highest_priority'] and self.env.ego.APs['intersection_is_clear'] \
+                    and np.random.rand() <= 0.1 and self.env.ego.v <= self._v_ref / 10 \
+                    and self.env.ego.acc < 0:
+                self._extra_r_terminal = - 100
+                return True
+            else:
+                self._extra_r_terminal = None
+                return False
+
+        return False
+
+    @staticmethod
+    def _features_dim_reduction(features_tuple):
+        return extract_ego_features(
+            features_tuple, 'v', 'v_ref', 'e_y', 'psi', 'v tan(psi/L)', 'theta', 'lane', 'acc', 'psi_dot',
+            'pos_stop_region', 'intersection_is_clear', 'highest_priority')


 class Left(ManeuverBase):
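The first precondition of Wait changes from a plain conjunction under G to an until formula: the ego must stay stopped in the stop region until it both has the highest priority and sees a clear intersection. For readers less used to LTL, here is a finite-trace sketch of the "p U q" semantics the property relies on (illustration only; the project's LTLProperty monitor may evaluate formulas incrementally rather than over a stored trace):

def holds_until(trace, p, q):
    # Finite-trace "p U q": q must eventually hold, and p must hold at
    # every step strictly before that.
    for state in trace:
        if q(state):
            return True
        if not p(state):
            return False
    return False  # q never became true

# Toy trace for the Wait precondition, with
#   p = in_stop_region and stopped_now
#   q = highest_priority and intersection_is_clear
trace = [
    {"in_stop_region": True, "stopped_now": True, "highest_priority": False, "intersection_is_clear": False},
    {"in_stop_region": True, "stopped_now": True, "highest_priority": True, "intersection_is_clear": True},
]
p = lambda s: s["in_stop_region"] and s["stopped_now"]
q = lambda s: s["highest_priority"] and s["intersection_is_clear"]
print(holds_until(trace, p, q))  # True: the ego waits until it can proceed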