Commit ea8874a8 authored by Jae Young Lee

Merge branch 'improve_and_bugfix_low_and_high_level_training' into 'master'

Improve and bugfix low and high level training

See merge request !1
parents f2171d2c 6da696a2
......@@ -51,21 +51,21 @@ class ControllerBase(PolicyBase):
Returns state at end of node execution, total reward, episode_termination_flag, info
'''
# TODO: this is never called when you test the high-level policy rather than train it...
def step_current_node(self, visualize_low_level_steps=False):
total_reward = 0
self.node_terminal_state_reached = False
while not self.node_terminal_state_reached:
observation, reward, terminal, info = self.low_level_step_current_node(
)
observation, reward, terminal, info = self.low_level_step_current_node()
if visualize_low_level_steps:
self.env.render()
# TODO: make the total_reward discounted....
total_reward += reward
total_reward += self.current_node.high_level_extra_reward
# TODO for info
return observation, total_reward, self.env.termination_condition, info
return observation, total_reward, terminal, info
# TODO: Looks generic. Move to an intermediate class/high-level manager so that the base class stays clean.
''' Executes one step of current node. Sets node_terminal_state_reached flag if node termination condition
......@@ -76,9 +76,7 @@ class ControllerBase(PolicyBase):
def low_level_step_current_node(self):
u_ego = self.current_node.low_level_policy(
self.current_node.get_reduced_features_tuple())
u_ego = self.current_node.low_level_policy(self.current_node.get_reduced_features_tuple())
feature, R, terminal, info = self.current_node.step(u_ego)
self.node_terminal_state_reached = terminal
return self.env.get_features_tuple(
), R, self.env.termination_condition, info
return self.env.get_features_tuple(), R, self.env.termination_condition, info
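A minimal sketch of the discounting mentioned in the TODO above, assuming a hypothetical discount attribute `self.gamma` on the controller (not part of this diff):

```python
def step_current_node(self, visualize_low_level_steps=False):
    """Discounted variant of the accumulation loop above (sketch only)."""
    total_reward = 0.0
    discount = 1.0
    self.node_terminal_state_reached = False
    while not self.node_terminal_state_reached:
        observation, reward, terminal, info = self.low_level_step_current_node()
        if visualize_low_level_steps:
            self.env.render()
        total_reward += discount * reward
        discount *= self.gamma  # hypothetical discount factor, e.g. 0.99
    total_reward += self.current_node.high_level_extra_reward
    return observation, total_reward, terminal, info
```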
......@@ -8,7 +8,7 @@ from keras.callbacks import TensorBoard
from rl.agents import DDPGAgent, DQNAgent
from rl.memory import SequentialMemory
from rl.random import OrnsteinUhlenbeckProcess
from rl.policy import BoltzmannQPolicy, MaxBoltzmannQPolicy
from rl.policy import GreedyQPolicy, EpsGreedyQPolicy, MaxBoltzmannQPolicy
from rl.callbacks import ModelIntervalCheckpoint
......@@ -229,6 +229,7 @@ class DQNLearner(LearnerBase):
model=None,
policy=None,
memory=None,
test_policy=None,
**kwargs):
"""The constructor which sets the properties of the class.
......@@ -236,8 +237,8 @@ class DQNLearner(LearnerBase):
input_shape: Shape of observation space, e.g (10,);
nb_actions: number of values in action space;
model: Keras Model of actor which takes observation as input and outputs actions. Uses default if not given
policy: KerasRL Policy. Uses default SequentialMemory if not given
memory: KerasRL Memory. Uses default BoltzmannQPolicy if not given
policy: KerasRL Policy. Uses default RestrictedEpsGreedyQPolicy if not given
memory: KerasRL Memory. Uses default SequentialMemory if not given
**kwargs: other optional key-value arguments with defaults defined in property_defaults
"""
super(DQNLearner, self).__init__(input_shape, nb_actions, **kwargs)
......@@ -255,12 +256,14 @@ class DQNLearner(LearnerBase):
model = self.get_default_model()
if policy is None:
policy = self.get_default_policy()
if test_policy is None:
test_policy = self.get_default_test_policy()
if memory is None:
memory = self.get_default_memory()
self.low_level_policies = low_level_policies
self.agent_model = self.create_agent(model, policy, memory)
self.agent_model = self.create_agent(model, policy, memory, test_policy)
def get_default_model(self):
"""Creates the default model.
......@@ -269,9 +272,11 @@ class DQNLearner(LearnerBase):
"""
model = Sequential()
model.add(Flatten(input_shape=(1, ) + self.input_shape))
model.add(Dense(32))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(32))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(self.nb_actions))
model.add(Activation('linear'))
......@@ -280,7 +285,10 @@ class DQNLearner(LearnerBase):
return model
def get_default_policy(self):
return MaxBoltzmannQPolicy(eps=0.3)
return RestrictedEpsGreedyQPolicy(0.3)
def get_default_test_policy(self):
return RestrictedGreedyQPolicy()
def get_default_memory(self):
"""Creates the default memory model.
......@@ -291,7 +299,7 @@ class DQNLearner(LearnerBase):
limit=self.mem_size, window_length=self.mem_window_length)
return memory
def create_agent(self, model, policy, memory):
def create_agent(self, model, policy, memory, test_policy):
"""Creates a KerasRL DDPGAgent with given components.
Args:
......@@ -309,6 +317,7 @@ class DQNLearner(LearnerBase):
nb_steps_warmup=self.nb_steps_warmup,
target_model_update=self.target_model_update,
policy=policy,
test_policy=test_policy,
enable_dueling_network=True)
agent.compile(Adam(lr=self.lr), metrics=['mae'])
......@@ -319,6 +328,8 @@ class DQNLearner(LearnerBase):
env,
nb_steps=1000000,
visualize=False,
verbose=1,
log_interval=10000,
nb_max_episode_steps=200,
tensorboard=False,
model_checkpoints=False,
......@@ -338,7 +349,8 @@ class DQNLearner(LearnerBase):
env,
nb_steps=nb_steps,
visualize=visualize,
verbose=1,
verbose=verbose,
log_interval=log_interval,
nb_max_episode_steps=nb_max_episode_steps,
callbacks=callbacks)
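For reference, a hedged usage sketch of the new `verbose` and `log_interval` arguments; the names `agent` and `options_env` are placeholders, not objects defined in this diff:

```python
# Hypothetical call site; `agent` is a DQNLearner and `options_env` a Gym-compliant env.
agent.train(
    options_env,
    nb_steps=200000,
    visualize=False,
    verbose=1,           # forwarded to keras-rl's fit()
    log_interval=10000,  # forwarded to keras-rl's fit()
    nb_max_episode_steps=200)
```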
......@@ -410,6 +422,82 @@ class DQNLearner(LearnerBase):
return relevant
class RestrictedEpsGreedyQPolicy(EpsGreedyQPolicy):
"""Implement the epsilon greedy policy
Restricted Eps Greedy policy.
This policy ensures that it never chooses the action whose value is -inf
"""
def __init__(self, eps=.1):
super(RestrictedEpsGreedyQPolicy, self).__init__(eps)
def select_action(self, q_values):
"""Return the selected action
# Arguments
q_values (np.ndarray): the estimated Q-value of each action
# Returns
Selected action
"""
assert q_values.ndim == 1
nb_actions = q_values.shape[0]
index = list()
for i in range(0, nb_actions):
if q_values[i] != -np.inf:
index.append(i)
# Every q_value can be -np.inf. This inevitably happens sometimes inside keras-rl's fit and test
# functions at the terminal stage, since they force a call to forward in the Keras-RL learner,
# which in turn calls this function.
# TODO: add proper exception handling, or a better way to choose an action in this case.
if len(index) < 1:
# every q_value is -np.inf, we choose action = 0
action = 0
print("Warning: no action satisfies initiation condition, action = 0 is chosen by default.")
elif np.random.uniform() <= self.eps:
action = index[np.random.random_integers(0, len(index) - 1)]
else:
action = np.argmax(q_values)
return action
class RestrictedGreedyQPolicy(GreedyQPolicy):
"""Implement the epsilon greedy policy
Restricted Greedy policy.
This policy ensures that it never chooses the action whose value is -inf
"""
def select_action(self, q_values):
"""Return the selected action
# Arguments
q_values (np.ndarray): the estimated Q-value of each action
# Returns
Selected action
"""
assert q_values.ndim == 1
# TODO: add proper exception handling, or a better way to choose an action in this case.
if np.max(q_values) == - np.inf:
# every q_value is -np.inf, we choose action = 0
action = 0
print("Warning: no action satisfies initiation condition, action = 0 is chosen by default.")
else:
action = np.argmax(q_values)
return action
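A short usage sketch (not part of the diff) showing how the two restricted policies behave when some Q-values are -inf:

```python
import numpy as np

# Assumes RestrictedGreedyQPolicy and RestrictedEpsGreedyQPolicy (defined above) are in scope.
q_values = np.array([-np.inf, 0.2, -np.inf, 0.7])

greedy = RestrictedGreedyQPolicy()
assert greedy.select_action(q_values) == 3  # argmax over the finite entries

eps_greedy = RestrictedEpsGreedyQPolicy(eps=0.3)
action = eps_greedy.select_action(q_values)
assert action in (1, 3)  # the -inf actions (0 and 2) are never selected
```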
class DQNAgentOverOptions(DQNAgent):
def __init__(self,
model,
......@@ -433,8 +521,10 @@ class DQNAgentOverOptions(DQNAgent):
def __get_invalid_node_indices(self):
"""Returns a list of option indices that are invalid according to
initiation conditions."""
invalid_node_indices = list()
for index, option_alias in enumerate(self.low_level_policy_aliases):
# TODO: Move reset_maneuver elsewhere, as this is a "get" function.
self.low_level_policies[option_alias].reset_maneuver()
if not self.low_level_policies[option_alias].initiation_condition:
invalid_node_indices.append(index)
......
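Presumably (an assumption, since the relevant hunk is not shown here), DQNAgentOverOptions masks the Q-values of these invalid options to -inf before action selection, which is what the restricted policies above rely on. A minimal sketch of that masking step:

```python
import numpy as np

def mask_invalid_options(q_values, invalid_node_indices):
    """Set the Q-values of invalid options to -inf so they can never be selected."""
    masked = np.asarray(q_values, dtype=float).copy()
    for i in invalid_node_indices:
        masked[i] = -np.inf
    return masked

q = mask_invalid_options([0.4, 0.9, 0.1], invalid_node_indices=[1])
# RestrictedGreedyQPolicy().select_action(q) now returns 0 and can never pick option 1.
```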
......@@ -30,6 +30,7 @@ class ManualPolicy(ControllerBase):
new_node = None
if self.low_level_policies[self.current_node].termination_condition:
for next_node in self.adj[self.current_node]:
self.low_level_policies[next_node].reset_maneuver()
if self.low_level_policies[next_node].initiation_condition:
new_node = next_node
break # change current_node to the highest priority next node
......
......@@ -25,7 +25,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
# three types possible ('min', 'max', or 'sum');
# See _reward_superposition below.
terminal_reward_type = 'max'
# TODO: consider the case where every terminal reward is None. Give this class a default terminal value (not None) and use it in that case.
terminal_reward_type = 'min'
#: If true, the maneuver terminates when the goal has been achieved.
_terminate_in_goal = False
......@@ -140,13 +141,11 @@ class EpisodicEnvBase(GymCompliantEnvBase):
def _reset_model_checker(self, AP):
self.__mc_AP = int(AP)
if self._LTL_preconditions_enable:
for LTL_precondition in self._LTL_preconditions:
LTL_precondition.reset_property()
if LTL_precondition.enabled:
LTL_precondition.check_incremental(self.__mc_AP)
self._incremental_model_checking(AP)
def _set_mc_AP(self, AP):
self.__mc_AP = int(AP)
......@@ -158,7 +157,7 @@ class EpisodicEnvBase(GymCompliantEnvBase):
if self._terminate_in_goal and self.goal_achieved:
return True
return self.violation_happened and self._LTL_preconditions_enable
return self.violation_happened
@property
def goal_achieved(self):
......@@ -176,8 +175,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
if not self._LTL_preconditions_enable:
return False
for LTL_precondition in self._LTL_preconditions:
if LTL_precondition.result == Parser.FALSE:
for LTL in self._LTL_preconditions:
if LTL.enabled and (LTL.result == Parser.FALSE):
return True
return False
......
......@@ -174,6 +174,6 @@ class Features(object):
# Add buffer features to make a fixed length feature vector
for i in range(MAX_NUM_VEHICLES - len(self.other_vehs)):
feature += (0.0, 0.0, 0.0, 0.0, -1)
feature += (0.0, 0.0, 0.0, 0.0, -1.0)
return feature
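A small, self-contained illustration of the padding above; `MAX_NUM_VEHICLES` and the 5-value per-vehicle layout are assumptions taken from the hunk, and the numbers are placeholders:

```python
MAX_NUM_VEHICLES = 4                        # assumed value for illustration
other_vehs = [(12.0, 0.0, 3.5, 0.0, 1.0)]   # one observed vehicle (placeholder numbers)

feature = ()
for veh in other_vehs:
    feature += veh
# Pad the remaining slots with -1.0 (a float) as the "no vehicle" marker so the
# vector is always fixed-length and uniformly float-typed.
for _ in range(MAX_NUM_VEHICLES - len(other_vehs)):
    feature += (0.0, 0.0, 0.0, 0.0, -1.0)

assert len(feature) == 5 * MAX_NUM_VEHICLES
assert all(isinstance(x, float) for x in feature)
```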
......@@ -93,7 +93,7 @@ vlanes = Route(
[-vwidth - 5.0 - intersection_voffset, -vwidth - intersection_voffset], 35,
[-4.0, 4.0])
intersection_width = vlanes.n_lanes * vlanes.width
intersection_height = hlanes.n_lanes * hlanes.width
intersection_width_w_offset = intersection_width + 2 * intersection_hoffset
......
......@@ -51,7 +51,10 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
#: The weight vector to calculate the cost. In the maneuver, cost_weights
# can be set to a specific value which may be different than the default.
cost_weights = (1.0, 0.25, 0.1, 1.0, 100.0, 0.1, 0.25, 0.1)
# TODO: check _cost_weights both here and in ManeuverBase. The _cost_weights from ManeuverBase should be substituted here, but sometimes it is not.
# TODO: add the ability to set _cost_weights separately for low- and high-level training.
_cost_weights = (10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)
#TODO: Move this to constants
# The empirical min and max of each term in the cost vector, which is used to normalize the values
......@@ -271,7 +274,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
# stopped_car_scenario = bool(np.random.randint(0, 1)) TODO: this scenario may not work
n_others_stopped_in_stop_region = np.random.randint(
0, min(3, n_others - stopped_car_scenario))
veh_ahead_scenario = bool(np.random.randint(0, 1))
veh_ahead_scenario = bool(np.random.randint(0, 2)) or veh_ahead_scenario
if n_others_stopped_in_stop_region > min(
n_others - stopped_car_scenario, 3):
......@@ -1156,12 +1159,13 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
Returns True if the environment has terminated
"""
model_checks_violated = (self._LTL_preconditions_enable and \
self.current_model_checking_result())
model_checks_violated = self._LTL_preconditions_enable and \
self.current_model_checking_result()
reached_goal = self._terminate_in_goal and self.goal_achieved
self._check_collisions()
self._check_ego_theta_out_of_range()
terminated = self.termination_condition
return model_checks_violated or reached_goal or terminated
@property
......@@ -1181,7 +1185,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
return (self.ego.x >= rd.hlanes.end_pos) and \
not self.collision_happened and \
not self.ego.APs['over_speed_limit']
(self.ego.v <= 1.1*rd.speed_limit)
def reset(self):
"""Gym compliant reset function.
......@@ -1229,7 +1233,6 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
self.window.dispatch_events()
# Text information about ego vehicle's states
# Right now, we are only training one option (Stop)
info = "Ego Attributes:" + get_APs(
self, EGO_INDEX, 'in_stop_region',
'has_entered_stop_region', 'has_stopped_in_stop_region',
......
......@@ -11,7 +11,7 @@ def high_level_policy_training(nb_steps=25000,
load_weights=False,
training=True,
testing=True,
nb_episodes_for_test=10,
nb_episodes_for_test=20,
max_nb_steps=100,
visualize=False,
tensorboard=False,
......@@ -63,8 +63,7 @@ def high_level_policy_training(nb_steps=25000,
agent.save_model(save_path)
if testing:
options.set_controller_policy(agent.predict)
agent.test_model(options, nb_episodes=nb_episodes_for_test)
high_level_policy_testing(nb_episodes_for_test=nb_episodes_for_test)
return agent
......@@ -228,7 +227,6 @@ if __name__ == "__main__":
load_weights=args.load_weights,
save_path=args.save_file,
tensorboard=args.tensorboard,
nb_episodes_for_test=20,
visualize=args.visualize)
if args.test:
......
import json
import os # for the use of os.path.isfile
from .simple_intersection.maneuvers import *
from .simple_intersection.mcts_maneuvers import *
from backends import RLController, DDPGLearner, MCTSLearner, OnlineMCTSController, ManualPolicy
......@@ -155,19 +156,29 @@ class OptionsGraph:
# TODO: error handling
def load_trained_low_level_policies(self):
for key, maneuver in self.maneuvers.items():
agent = DDPGLearner(
input_shape=(maneuver.get_reduced_feature_length(), ),
nb_actions=2,
gamma=0.99,
nb_steps_warmup_critic=200,
nb_steps_warmup_actor=200,
lr=1e-3)
agent.load_model("backends/trained_policies/" + key + "/" + key +
"_weights.h5f")
maneuver.set_low_level_trained_policy(agent.predict)
maneuver._cost_weights = (20.0 * 1e-3, 1.0 * 1e-3, 0.25 * 1e-3,
1.0 * 1e-3, 100.0 * 1e-3, 0.1 * 1e-3,
0.25 * 1e-3, 0.1 * 1e-3)
trained_policy_path = "backends/trained_policies/" + key + "/"
critic_file_exists = os.path.isfile(trained_policy_path + key + "_weights_critic.h5f")
actor_file_exists = os.path.isfile(trained_policy_path + key + "_weights_actor.h5f")
if actor_file_exists and critic_file_exists:
agent = DDPGLearner(
input_shape=(maneuver.get_reduced_feature_length(),),
nb_actions=2,
gamma=0.99,
nb_steps_warmup_critic=200,
nb_steps_warmup_actor=200,
lr=1e-3)
agent.load_model(trained_policy_path + key + "_weights.h5f")
maneuver.set_low_level_trained_policy(agent.predict)
elif not critic_file_exists and actor_file_exists:
print("\n Warning: unable to load the low-level policy of \"" + key +
"\". the file of critic weights have to be located in the same " +
"directory of the actor weights file; the manual policy will be used instead.\n")
else:
print("\n Warning: the trained low-level policy of \"" + key +
"\" does not exists; the manual policy will be used.\n")
if self.config["method"] == "mcts":
maneuver.timeout = np.inf
......
......@@ -20,7 +20,7 @@ class ManeuverBase(EpisodicEnvBase):
learning_mode = 'training'
#: timeout (i.e., time horizon for termination)
# By default, the time-out horizon is 1 as in Paxton et. al (2017).
# By default, the time-out horizon is 1.
timeout = 1
#: the option specific weight vector for cost of driving, which is
......@@ -29,8 +29,9 @@ class ManeuverBase(EpisodicEnvBase):
# _extra_action_weights_flag = True); note that a cost is defined
# as a negative reward, so each cost term is subtracted from the reward.
_cost_weights = (1.0 * 1e-3, 1.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
100.0 * 1e-3, 0.1 * 1e-3, 0.25 * 1e-3, 0.1 * 1e-3)
# TODO: either remove _cost_weights from ManeuverBase or keep it here to provide additional functionality (see the related TODOs in simple_intersection_env regarding _cost_weights).
_cost_weights = (10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)
_extra_r_terminal = None
_extra_r_on_timeout = None
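To make the comment about costs concrete, a sketch of how the weighted cost terms would be subtracted from the reward; this is an assumption about usage, not code from this diff:

```python
import numpy as np

# Option-specific weights from the hunk above.
_cost_weights = (10.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
                 100.0 * 1e-3, 0.1 * 1e-3, 0.05 * 1e-3, 0.1 * 1e-3)

cost_terms = np.ones(len(_cost_weights))  # placeholder per-step cost terms
task_reward = 0.0                         # reward produced by the step itself

# A cost is a negative reward: subtract the weighted sum from the step reward.
reward = task_reward - float(np.dot(_cost_weights, cost_terms))
```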
......@@ -38,7 +39,7 @@ class ManeuverBase(EpisodicEnvBase):
#: When False, _cost_weights is used without modification;
# if True, the action parts of _cost_weights are increased
# for some edge cases (see the step method).
_extra_action_weights_flag = True
_extra_action_weights_flag = False
#: the extra weights on the actions added to _cost_weights
# for some edge cases when _extra_action_weights_flag = True.
......@@ -153,8 +154,7 @@ class ManeuverBase(EpisodicEnvBase):
# in this case, no additional reward by Default
# (i.e., self._extra_r_terminal = None by default).
self._terminal_reward_superposition(self._extra_r_terminal)
info[
'maneuver_termination_reason'] = 'extra_termination_condition'
info['maneuver_termination_reason'] = 'extra_termination_condition'
if self.timeout_happened:
if self._give_reward_on_timeout:
# in this case, no additional reward by Default
......@@ -292,9 +292,8 @@ class ManeuverBase(EpisodicEnvBase):
raise NotImplementedError(self.__class__.__name__ +
".generate_learning_scenario is not implemented.")
def generate_validation_scenario(
self
): # Override this method in the subclass if some customization is needed.
# Override this method in the subclass if some customization is needed.
def generate_validation_scenario(self):
self.generate_learning_scenario()
self._enable_low_level_training_properties = False
......@@ -334,8 +333,7 @@ class ManeuverBase(EpisodicEnvBase):
Returns True if the condition is satisfied, and False otherwise.
"""
return not (self.env.termination_condition or self.violation_happened) and \
self.extra_initiation_condition
return not self.termination_condition and self.extra_initiation_condition
@property
def extra_initiation_condition(self):
......
This diff is collapsed.