Commit 6da696a2 authored by Jae Young Lee

Fix the bug of sometimes having -100000 or 10000 rewards.

parent 4a9327bd
@@ -65,7 +65,7 @@ class ControllerBase(PolicyBase):
total_reward += self.current_node.high_level_extra_reward
# TODO for info
return observation, total_reward, self.env.termination_condition, info
return observation, total_reward, terminal, info
# TODO: Looks generic. Move to an intermediate class/highlevel manager so that base class can be clean
''' Executes one step of current node. Sets node_terminal_state_reached flag if node termination condition
@@ -76,9 +76,7 @@ class ControllerBase(PolicyBase):
def low_level_step_current_node(self):
u_ego = self.current_node.low_level_policy(
self.current_node.get_reduced_features_tuple())
u_ego = self.current_node.low_level_policy(self.current_node.get_reduced_features_tuple())
feature, R, terminal, info = self.current_node.step(u_ego)
self.node_terminal_state_reached = terminal
return self.env.get_features_tuple(
), R, self.env.termination_condition, info
return self.env.get_features_tuple(), R, self.env.termination_condition, info
@@ -237,7 +237,7 @@ class DQNLearner(LearnerBase):
input_shape: Shape of observation space, e.g (10,);
nb_actions: number of values in action space;
model: Keras Model of actor which takes observation as input and outputs actions. Uses default if not given
policy: KerasRL Policy. Uses default MaxBoltzmannQPolicy if not given
policy: KerasRL Policy. Uses default RestrictedEpsGreedyQPolicy if not given
memory: KerasRL Memory. Uses default SequentialMemory if not given
**kwargs: other optional key-value arguments with defaults defined in property_defaults
"""
@@ -317,6 +317,7 @@ class DQNLearner(LearnerBase):
nb_steps_warmup=self.nb_steps_warmup,
target_model_update=self.target_model_update,
policy=policy,
test_policy=test_policy,
enable_dueling_network=True)
agent.compile(Adam(lr=self.lr), metrics=['mae'])
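Note: keras-rl's DQNAgent evaluates with a plain GreedyQPolicy unless a test_policy is supplied, so forwarding test_policy above lets testing honour the same action restriction as training. A self-contained sketch of that wiring, with stock keras-rl policies and a toy model standing in for the Restricted* policies and the learner's own model (all names below are placeholders, not the learner's attributes):

from keras.models import Sequential
from keras.layers import Flatten, Dense
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import EpsGreedyQPolicy, GreedyQPolicy

nb_actions = 5
model = Sequential([Flatten(input_shape=(1, 10)),
                    Dense(32, activation='relu'),
                    Dense(nb_actions, activation='linear')])

agent = DQNAgent(model=model,
                 nb_actions=nb_actions,
                 memory=SequentialMemory(limit=10000, window_length=1),
                 nb_steps_warmup=100,
                 target_model_update=1e-2,
                 policy=EpsGreedyQPolicy(eps=0.1),  # training-time policy
                 test_policy=GreedyQPolicy(),       # now forwarded for evaluation
                 enable_dueling_network=True)
agent.compile(Adam(lr=1e-3), metrics=['mae'])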
@@ -327,6 +328,8 @@ class DQNLearner(LearnerBase):
env,
nb_steps=1000000,
visualize=False,
verbose=1,
log_interval=10000,
nb_max_episode_steps=200,
tensorboard=False,
model_checkpoints=False,
@@ -346,7 +349,8 @@ class DQNLearner(LearnerBase):
env,
nb_steps=nb_steps,
visualize=visualize,
verbose=1,
verbose=verbose,
log_interval=log_interval,
nb_max_episode_steps=nb_max_episode_steps,
callbacks=callbacks)
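The new verbose and log_interval arguments are forwarded unchanged to keras-rl's fit(), which accepts both (verbose=1 prints interval summaries every log_interval steps, verbose=2 logs per episode). A hypothetical call to the train method declared above; learner and env are placeholders:

learner.train(
    env,
    nb_steps=1000000,
    nb_max_episode_steps=200,
    verbose=1,           # interval logging, as in keras-rl's fit()
    log_interval=10000)  # steps between progress summaries when verbose == 1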
@@ -448,9 +452,11 @@ class RestrictedEpsGreedyQPolicy(EpsGreedyQPolicy):
# every q_value is -np.inf (this sometimes inevitably happens within the fit and test functions
# of kerasrl at the terminal stage as they force to call forward in Kerasrl-learner which calls this function.
# In this case, we choose a policy randomly.
# TODO: exception process or some more process to choose action in this exceptional case.
if len(index) < 1:
action = np.random.random_integers(0, nb_actions - 1)
# every q_value is -np.inf, we choose action = 0
action = 0
print("Warning: no action satisfies initiation condition, action = 0 is chosen by default.")
elif np.random.uniform() <= self.eps:
action = index[np.random.random_integers(0, len(index) - 1)]
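For context, a sketch of how the patched select_action reads as a whole. The exploit branch and the use of np.random.randint (in place of the deprecated np.random.random_integers) are assumptions about the lines not shown in this hunk:

import numpy as np
from rl.policy import EpsGreedyQPolicy


class RestrictedEpsGreedyQPolicy(EpsGreedyQPolicy):
    """Eps-greedy over the actions whose Q-value is not -np.inf."""

    def select_action(self, q_values):
        assert q_values.ndim == 1
        nb_actions = q_values.shape[0]
        # actions whose initiation condition holds, i.e. Q-value not masked to -inf
        index = [a for a in range(nb_actions) if q_values[a] != -np.inf]

        if len(index) < 1:
            # every Q-value is -np.inf: fall back to action 0 (this commit)
            action = 0
            print("Warning: no action satisfies initiation condition, "
                  "action = 0 is chosen by default.")
        elif np.random.uniform() <= self.eps:
            # explore uniformly among the valid actions only
            action = index[np.random.randint(len(index))]
        else:
            # exploit: greedy over the full array; masked actions cannot win here
            action = int(np.argmax(q_values))
        return action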
@@ -479,21 +485,15 @@ class RestrictedGreedyQPolicy(GreedyQPolicy):
Selection action
"""
assert q_values.ndim == 1
nb_actions = q_values.shape[0]
restricted_q_values = list()
for i in range(0, nb_actions):
if q_values[i] != -np.inf:
restricted_q_values.append(q_values[i])
# every q_value is -np.inf (this sometimes inevitably happens within the fit and test functions
# of kerasrl at the terminal stage as they force to call forward in Kerasrl-learner which calls this function.
# In this case, we choose a policy randomly.
if len(restricted_q_values) < 1:
action = np.random.random_integers(0, nb_actions - 1)
# TODO: exception process or some more process to choose action in this exceptional case.
if np.max(q_values) == - np.inf:
# every q_value is -np.inf, we choose action = 0
action = 0
print("Warning: no action satisfies initiation condition, action = 0 is chosen by default.")
else:
action = np.argmax(restricted_q_values)
action = np.argmax(q_values)
return action
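The old greedy branch took argmax over the filtered restricted_q_values list, whose indices stop matching the original action indices as soon as any action is masked out; that mismatch, together with the random fallback it replaces, is a plausible source of the spurious -100000/10000 rewards named in the commit message. A small illustration of the difference:

import numpy as np

q_values = np.array([-np.inf, 1.0, 5.0])

# Old behaviour: argmax over the filtered list returns an index into that
# list (1 here), which is not the index of the best action in q_values.
restricted_q_values = [q for q in q_values if q != -np.inf]
old_action = int(np.argmax(restricted_q_values))  # -> 1 (wrong option)

# New behaviour: argmax over the full array; a -inf entry can only win when
# every entry is -inf, and that case now falls back to action = 0.
new_action = int(np.argmax(q_values))             # -> 2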
@@ -521,6 +521,7 @@ class DQNAgentOverOptions(DQNAgent):
def __get_invalid_node_indices(self):
"""Returns a list of option indices that are invalid according to
initiation conditions."""
invalid_node_indices = list()
for index, option_alias in enumerate(self.low_level_policy_aliases):
# TODO: Locate reset_maneuver to another place as this is a "get" function.
......
@@ -3,7 +3,7 @@
"wait": "Wait",
"follow": "Follow",
"stop": "Stop",
"immediatestop": "ImmediateStop",
"changelane": "ChangeLane",
"keeplane": "KeepLane"
},
......
@@ -157,7 +157,7 @@ class EpisodicEnvBase(GymCompliantEnvBase):
if self._terminate_in_goal and self.goal_achieved:
return True
return self.violation_happened and self._LTL_preconditions_enable
return self.violation_happened
@property
def goal_achieved(self):
......
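After this change a detected violation terminates the episode even when LTL precondition checking is disabled. A minimal, self-contained stand-in for the post-change logic (the class name and property name are assumed from the context shown in this diff, not taken from the repository):

class EpisodicEnvSketch:
    """Stand-in showing the post-change termination logic only."""

    def __init__(self, terminate_in_goal, goal_achieved, violation_happened):
        self._terminate_in_goal = terminate_in_goal
        self.goal_achieved = goal_achieved
        self.violation_happened = violation_happened

    @property
    def termination_condition(self):
        if self._terminate_in_goal and self.goal_achieved:
            return True
        # A violation now ends the episode regardless of
        # _LTL_preconditions_enable, which is no longer consulted here.
        return self.violation_happened


assert EpisodicEnvSketch(False, False, True).termination_condition is True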
@@ -93,7 +93,7 @@ vlanes = Route(
[-vwidth - 5.0 - intersection_voffset, -vwidth - intersection_voffset], 35,
[-4.0, 4.0])
intersection_width = vlanes.n_lanes * vlanes.width
intersection_width = vlanes.n_lanes * vlanes.width
intersection_height = hlanes.n_lanes * hlanes.width
intersection_width_w_offset = intersection_width + 2 * intersection_hoffset
......
@@ -51,7 +51,10 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
#: The weight vector to calculate the cost. In the maneuver, cost_weights
# can be set to a specific value which may be different than the default.
cost_weights = (1.0, 0.25, 0.1, 1.0, 100.0, 0.1, 0.25, 0.1)
# TODO: check _cost_weights in both here and ManeuverBase. The _cost_weights has to be substituted to here, but it doesn't sometimes.
# TODO: set a functionality of setting _cost_weights for low and high level training separately.
_cost_weights = (10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)
#TODO: Move this to constants
# The empirical min and max of each term in the cost vector, which is used to normalize the values
@@ -271,7 +274,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
# stopped_car_scenario = bool(np.random.randint(0, 1)) TODO: this scenario may not work
n_others_stopped_in_stop_region = np.random.randint(
0, min(3, n_others - stopped_car_scenario))
veh_ahead_scenario = bool(np.random.randint(0, 1)) or veh_ahead_scenario
veh_ahead_scenario = bool(np.random.randint(0, 2)) or veh_ahead_scenario
if n_others_stopped_in_stop_region > min(
n_others - stopped_car_scenario, 3):
......
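The scenario fix above is an off-by-one with np.random.randint, whose upper bound is exclusive: randint(0, 1) can only ever return 0, so veh_ahead_scenario was never switched on at random. A quick check:

import numpy as np

# Old call: the upper bound is exclusive, so only 0 is drawn and
# bool(np.random.randint(0, 1)) is always False.
assert set(np.random.randint(0, 1, size=1000)) == {0}

# Fixed call: draws uniformly from {0, 1}, so the scenario is now
# enabled roughly half of the time.
draws = np.random.randint(0, 2, size=10000)
assert set(draws) == {0, 1}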
@@ -180,11 +180,6 @@ class OptionsGraph:
print("\n Warning: the trained low-level policy of \"" + key +
"\" does not exists; the manual policy will be used.\n")
# setting the cost weights for high-level policy training.
# TODO: this shouldn't be initialized here, but within ManeuverBase class (e.g. make some flag indicating high-level training and set the weights)...
maneuver._cost_weights = (100.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
100.0 * 1e-3, 0.1 * 1e-3, 0.25 * 1e-3, 0.1 * 1e-3)
if self.config["method"] == "mcts":
maneuver.timeout = np.inf
......
@@ -29,8 +29,9 @@ class ManeuverBase(EpisodicEnvBase):
# _extra_action_weights_flag = True); note that a cost is defined
# as a negative reward, so a cost will be summed up to the reward
# with subtraction.
# TODO: remove or to provide additional functionality, keep _cost_weights in ManeuverBase here (see other TODOs in simple_intersection_env regarding _cost_weights).
_cost_weights = (10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
100.0 * 1e-3, 0.1 * 1e-3, 0.25 * 1e-3, 0.1 * 1e-3)
100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)
_extra_r_terminal = None
_extra_r_on_timeout = None
@@ -38,7 +39,7 @@ class ManeuverBase(EpisodicEnvBase):
#: the flag being False when _cost_weights is used without
# modification; If True, then the action parts of _cost_weights
# are increased for some edge cases (see the step method).
_extra_action_weights_flag = True
_extra_action_weights_flag = False
#: the extra weights on the actions added to _cost_weights
# for some edge cases when _extra_action_weights_flag = True.
......
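The comments above define a cost as a negative reward weighted by _cost_weights and subtracted from the step reward. How the individual terms are combined is not shown in this diff, so the following is only a hedged sketch of that weighted-sum reading, with placeholder cost terms:

import numpy as np

# Updated ManeuverBase defaults from this commit (note the 0.05e-3 entry).
_cost_weights = np.array((10.0e-3, 10.0e-3, 0.25e-3, 1.0e-3,
                          100.0e-3, 0.1e-3, 0.05e-3, 0.1e-3))


def shaped_reward(base_reward, cost_terms):
    """Assumed combination: weighted cost terms are subtracted from the reward."""
    cost_terms = np.asarray(cost_terms, dtype=float)  # one value per weight
    return base_reward - float(np.dot(_cost_weights, cost_terms))


# Purely illustrative numbers.
print(shaped_reward(1.0, np.ones(8)))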