Commit c0035ce5 authored by Jae Young Lee

Merge branch 'master' into improving_and_refactoring

parents 35249ff6 895596a8
@@ -237,7 +237,7 @@ class DQNLearner(LearnerBase):
             input_shape: Shape of observation space, e.g (10,);
             nb_actions: number of values in action space;
             model: Keras Model of actor which takes observation as input and outputs actions. Uses default if not given
-            policy: KerasRL Policy. Uses default MaxBoltzmannQPolicy if not given
+            policy: KerasRL Policy. Uses default RestrictedEpsGreedyQPolicy if not given
             memory: KerasRL Memory. Uses default SequentialMemory if not given
             **kwargs: other optional key-value arguments with defaults defined in property_defaults
         """
@@ -452,9 +452,11 @@ class RestrictedEpsGreedyQPolicy(EpsGreedyQPolicy):
         # every q_value is -np.inf (this sometimes inevitably happens within the fit and test functions
         # of kerasrl at the terminal stage as they force to call forward in Kerasrl-learner which calls this function.
-        # In this case, we choose a policy randomly.
+        # TODO: exception process or some more process to choose action in this exceptional case.
         if len(index) < 1:
-            action = np.random.random_integers(0, nb_actions - 1)
+            # every q_value is -np.inf, we choose action = 0
+            action = 0
+            print("Warning: no action satisfies initiation condition, action = 0 is chosen by default.")
         elif np.random.uniform() <= self.eps:
             action = index[np.random.random_integers(0, len(index) - 1)]
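For context, a self-contained sketch of the restricted epsilon-greedy selection this hunk patches. Only the len(index) < 1 branch mirrors the commit; how index is built and the keras-rl base class wiring are reconstructed assumptions:

    import numpy as np
    from rl.policy import EpsGreedyQPolicy

    class RestrictedEpsGreedyQPolicy(EpsGreedyQPolicy):
        """Epsilon-greedy over the actions whose q-value has not been masked to -inf."""

        def select_action(self, q_values):
            assert q_values.ndim == 1
            nb_actions = q_values.shape[0]
            # actions whose initiation condition holds, i.e. whose q-value is not -inf
            index = [a for a in range(nb_actions) if q_values[a] != -np.inf]

            if len(index) < 1:
                # every q_value is -np.inf; fall back to action = 0 as in the commit
                action = 0
                print("Warning: no action satisfies initiation condition, action = 0 is chosen by default.")
            elif np.random.uniform() <= self.eps:
                # explore, but only among the admissible actions
                action = index[np.random.randint(0, len(index))]
            else:
                # exploit: argmax over q_values (masked actions can never win here)
                action = np.argmax(q_values)
            return action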
@@ -484,9 +486,11 @@ class RestrictedGreedyQPolicy(GreedyQPolicy):
         """
         assert q_values.ndim == 1
+        # TODO: exception process or some more process to choose action in this exceptional case.
         if np.max(q_values) == - np.inf:
             # every q_value is -np.inf, we choose action = 0
             action = 0
+            print("Warning: no action satisfies initiation condition, action = 0 is chosen by default.")
         else:
             action = np.argmax(q_values)
@@ -551,4 +555,4 @@ class DQNAgentOverOptions(DQNAgent):
         for node_index in invalid_node_indices:
             q_values[node_index] = -np.inf
-        return q_values
\ No newline at end of file
+        return q_values
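The hunk above is the other half of the picture: options whose initiation condition fails get their q-values masked to -inf, which is what can leave the restricted policies with no admissible action. A standalone sketch of that masking (the function name is a stand-in, not the project's method):

    import numpy as np

    def mask_invalid_options(q_values, invalid_node_indices):
        """Return a copy of q_values with every non-initiable option set to -inf."""
        q_values = np.array(q_values, dtype=float)
        for node_index in invalid_node_indices:
            q_values[node_index] = -np.inf
        return q_values

    # e.g. with three options where option 1 cannot be initiated:
    masked = mask_invalid_options([0.2, 0.9, -0.1], invalid_node_indices=[1])
    # masked -> array([0.2, -inf, -0.1]); the restricted policies above then ignore option 1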
@@ -52,6 +52,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
     #: The weight vector to calculate the cost. In the maneuver, cost_weights
     # can be set to a specific value which may be different than the default.
     # TODO: check _cost_weights in both here and ManeuverBase. The _cost_weights has to be substituted to here, but it doesn't sometimes.
+    # TODO: set a functionality of setting _cost_weights for low and high level training separately.
     _cost_weights = (10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
                      100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)
@@ -1292,4 +1293,4 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
             if veh.waited_count > self.ego.waited_count:
                 n_others_with_higher_priority += 1
-        return n_others_with_higher_priority
\ No newline at end of file
+        return n_others_with_higher_priority
@@ -29,6 +29,7 @@ class ManeuverBase(EpisodicEnvBase):
    # _extra_action_weights_flag = True); note that a cost is defined
    # as a negative reward, so a cost will be summed up to the reward
    # with subtraction.
+    # TODO: remove or to provide additional functionality, keep _cost_weights in ManeuverBase here (see other TODOs in simple_intersection_env regarding _cost_weights).
    _cost_weights = (10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
                     100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)
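Both TODO hunks touch the same eight-element weight vector. A minimal sketch of how such a vector is typically applied, assuming the cost is the dot product of _cost_weights with a per-step cost-feature vector and, as the comment above says, enters the reward by subtraction (the function name and the feature vector are assumptions):

    import numpy as np

    _cost_weights = (10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
                     100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)

    def step_reward(base_reward, cost_features):
        """Weight the eight cost features and subtract the resulting cost from the reward."""
        cost = float(np.dot(_cost_weights, cost_features))
        return base_reward - cost

    # e.g. step_reward(1.0, (0.5,) * 8) subtracts 0.5 * sum(_cost_weights) from the reward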
@@ -152,6 +152,7 @@ class Stop(ManeuverBase):
             LTLProperty("G ( not has_stopped_in_stop_region )",
                         self._penalty(self._reward_in_goal), not self._enable_low_level_training_properties))
+        # before_intersection rather than "before_but_close_to_stop_region or in_stop_region"?
         self._LTL_preconditions.append(
             LTLProperty(
                 "G ( (before_but_close_to_stop_region or in_stop_region) U has_stopped_in_stop_region )",
@@ -492,4 +493,4 @@ class Follow(ManeuverBase):
             return ego_features + extract_other_veh_features(
                 features_tuple, self._target_veh_i, 'rel_x', 'rel_y', 'v', 'acc')
         else:
-            return ego_features + (0.0, 0.0, 0.0, 0.0)
\ No newline at end of file
+            return ego_features + (0.0, 0.0, 0.0, 0.0)
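A small sketch of the padding idea in this Follow.get_features hunk: the feature tuple keeps a fixed length, so the Keras model's input shape is unchanged when there is no target vehicle (the helper below is a stand-in, not the project's extract_other_veh_features):

    def follow_features(ego_features, target_veh_features=None):
        """Append the target vehicle's (rel_x, rel_y, v, acc) features, or zeros if there is none."""
        if target_veh_features is not None:
            return ego_features + tuple(target_veh_features)
        return ego_features + (0.0, 0.0, 0.0, 0.0)

    # follow_features((0.5, 1.0), (3.2, -0.4, 8.0, 0.1)) -> (0.5, 1.0, 3.2, -0.4, 8.0, 0.1)
    # follow_features((0.5, 1.0))                        -> (0.5, 1.0, 0.0, 0.0, 0.0, 0.0)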