Skip to content
Snippets Groups Projects
Commit c0035ce5 authored by Jae Young Lee's avatar Jae Young Lee
Browse files

Merge branch 'master' into improving_and_refactoring

parents 35249ff6 895596a8
No related branches found
No related tags found
No related merge requests found
......@@ -237,7 +237,7 @@ class DQNLearner(LearnerBase):
input_shape: Shape of observation space, e.g (10,);
nb_actions: number of values in action space;
model: Keras Model of actor which takes observation as input and outputs actions. Uses default if not given
policy: KerasRL Policy. Uses default MaxBoltzmannQPolicy if not given
policy: KerasRL Policy. Uses default RestrictedEpsGreedyQPolicy if not given
memory: KerasRL Memory. Uses default SequentialMemory if not given
**kwargs: other optional key-value arguments with defaults defined in property_defaults
"""
......@@ -452,9 +452,11 @@ class RestrictedEpsGreedyQPolicy(EpsGreedyQPolicy):
# every q_value is -np.inf (this sometimes inevitably happens within the fit and test functions
# of kerasrl at the terminal stage as they force to call forward in Kerasrl-learner which calls this function.
# In this case, we choose a policy randomly.
# TODO: exception process or some more process to choose action in this exceptional case.
if len(index) < 1:
action = np.random.random_integers(0, nb_actions - 1)
# every q_value is -np.inf, we choose action = 0
action = 0
print("Warning: no action satisfies initiation condition, action = 0 is chosen by default.")
elif np.random.uniform() <= self.eps:
action = index[np.random.random_integers(0, len(index) - 1)]
......@@ -484,9 +486,11 @@ class RestrictedGreedyQPolicy(GreedyQPolicy):
"""
assert q_values.ndim == 1
# TODO: exception process or some more process to choose action in this exceptional case.
if np.max(q_values) == - np.inf:
# every q_value is -np.inf, we choose action = 0
action = 0
print("Warning: no action satisfies initiation condition, action = 0 is chosen by default.")
else:
action = np.argmax(q_values)
......@@ -551,4 +555,4 @@ class DQNAgentOverOptions(DQNAgent):
for node_index in invalid_node_indices:
q_values[node_index] = -np.inf
return q_values
return q_values
\ No newline at end of file
......@@ -52,6 +52,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
#: The weight vector to calculate the cost. In the maneuver, cost_weights
# can be set to a specific value which may be different than the default.
# TODO: check _cost_weights in both here and ManeuverBase. The _cost_weights has to be substituted to here, but it doesn't sometimes.
# TODO: set a functionality of setting _cost_weights for low and high level training separately.
_cost_weights = (10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)
......@@ -1292,4 +1293,4 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
if veh.waited_count > self.ego.waited_count:
n_others_with_higher_priority += 1
return n_others_with_higher_priority
return n_others_with_higher_priority
\ No newline at end of file
......@@ -29,6 +29,7 @@ class ManeuverBase(EpisodicEnvBase):
# _extra_action_weights_flag = True); note that a cost is defined
# as a negative reward, so a cost will be summed up to the reward
# with subtraction.
# TODO: remove or to provide additional functionality, keep _cost_weights in ManeuverBase here (see other TODOs in simple_intersection_env regarding _cost_weights).
_cost_weights = (10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)
......
......@@ -152,6 +152,7 @@ class Stop(ManeuverBase):
LTLProperty("G ( not has_stopped_in_stop_region )",
self._penalty(self._reward_in_goal), not self._enable_low_level_training_properties))
# before_intersection rather than "before_but_close_to_stop_region or in_stop_region"?
self._LTL_preconditions.append(
LTLProperty(
"G ( (before_but_close_to_stop_region or in_stop_region) U has_stopped_in_stop_region )",
......@@ -492,4 +493,4 @@ class Follow(ManeuverBase):
return ego_features + extract_other_veh_features(
features_tuple, self._target_veh_i, 'rel_x', 'rel_y', 'v', 'acc')
else:
return ego_features + (0.0, 0.0, 0.0, 0.0)
return ego_features + (0.0, 0.0, 0.0, 0.0)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment