Commit a90b4bc5 authored by Jae Young Lee

Add and train more low-level policies, train a high-level policy.

The high-level policy was trained without the changelane maneuver but with the immediatestop maneuver. Two problems remain: 1) the agent chooses the changelane maneuver too frequently; 2) before the stop region, the immediatestop maneuver works but is not chosen properly after 2.5m during high-level policy training...
parent 70ad9bf5
@@ -51,7 +51,7 @@ class ControllerBase(PolicyBase):
        Returns state at end of node execution, total reward, episode_termination_flag, info
        '''
+    # TODO: this is never called when you test the high-level policy rather than train it...
    def step_current_node(self, visualize_low_level_steps=False):
        total_reward = 0
        self.node_terminal_state_reached = False
@@ -59,6 +59,7 @@ class ControllerBase(PolicyBase):
            observation, reward, terminal, info = self.low_level_step_current_node()
            if visualize_low_level_steps:
                self.env.render()
+           # TODO: make the total_reward discounted....
            total_reward += reward

        total_reward += self.current_node.high_level_extra_reward

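The TODO above asks for the accumulated low-level reward to be discounted. A minimal, standalone sketch of that idea (the helper name and the gamma value are illustrative, not taken from the repository):

def discounted_return(rewards, gamma=0.99):
    # Sum the rewards collected inside one option, discounting by step index.
    total = 0.0
    for t, r in enumerate(rewards):
        total += (gamma ** t) * r
    return total

# e.g. discounted_return([1.0, 1.0, 1.0], gamma=0.9) == 1.0 + 0.9 + 0.81 == 2.71
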
@@ -269,9 +269,12 @@ class DQNLearner(LearnerBase):
        """
        model = Sequential()
        model.add(Flatten(input_shape=(1, ) + self.input_shape))
-       model.add(Dense(32))
+       #model.add(Dense(64))
+       model.add(Dense(64))
        model.add(Activation('relu'))
-       model.add(Dense(32))
+       model.add(Dense(64))
+       model.add(Activation('relu'))
+       model.add(Dense(64))
        model.add(Activation('relu'))
        model.add(Dense(self.nb_actions))
        model.add(Activation('linear'))
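
Read off the new (right-hand) side of the hunk, the Q-network over options now has three hidden layers of 64 units instead of two of 32. A sketch of the resulting model, assuming the standard Keras Sequential API and keras-rl's (1,) + input_shape convention:

from keras.models import Sequential
from keras.layers import Flatten, Dense, Activation

def build_q_network(input_shape, nb_actions):
    # Mirrors the new DQNLearner model: 3 x (Dense(64) + ReLU),
    # followed by a linear output head over the discrete option set.
    model = Sequential()
    model.add(Flatten(input_shape=(1,) + input_shape))
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dense(nb_actions))
    model.add(Activation('linear'))
    return model
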
@@ -435,6 +438,7 @@ class DQNAgentOverOptions(DQNAgent):
        initiation conditions."""
        invalid_node_indices = list()
        for index, option_alias in enumerate(self.low_level_policy_aliases):
+           # TODO: move reset_maneuver somewhere else, as this is a "get" function.
            self.low_level_policies[option_alias].reset_maneuver()
            if not self.low_level_policies[option_alias].initiation_condition:
                invalid_node_indices.append(index)

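The indices gathered here presumably let the agent rule out options whose initiation condition fails. A hypothetical sketch of masking Q-values with such indices (the actual selection code is not part of this hunk):

import numpy as np

def select_valid_option(q_values, invalid_node_indices):
    # Suppress options that cannot be initiated, then greedily pick among the rest.
    masked = np.asarray(q_values, dtype=float).copy()
    masked[list(invalid_node_indices)] = -np.inf
    return int(np.argmax(masked))

# e.g. select_valid_option([0.2, 0.9, 0.5], invalid_node_indices=[1]) returns 2
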
@@ -30,6 +30,7 @@ class ManualPolicy(ControllerBase):
        new_node = None
        if self.low_level_policies[self.current_node].termination_condition:
            for next_node in self.adj[self.current_node]:
+               self.low_level_policies[next_node].reset_maneuver()
                if self.low_level_policies[next_node].initiation_condition:
                    new_node = next_node
                    break  # change current_node to the highest priority next node

@@ -3,7 +3,7 @@
        "wait": "Wait",
        "follow": "Follow",
        "stop": "Stop",
-       "changelane": "ChangeLane",
+       "immediatestop": "ImmediateStop",
        "keeplane": "KeepLane"
    },

@@ -174,6 +174,6 @@ class Features(object):
        # Add buffer features to make a fixed length feature vector
        for i in range(MAX_NUM_VEHICLES - len(self.other_vehs)):
-           feature += (0.0, 0.0, 0.0, 0.0, -1)
+           feature += (0.0, 0.0, 0.0, 0.0, -1.0)

        return feature
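
The loop pads the feature tuple so its length is fixed no matter how many other vehicles are present. A standalone sketch of the same padding (the constant's value and the function name here are illustrative):

MAX_NUM_VEHICLES = 4  # illustrative; the real constant is defined in the features module

def pad_other_vehicle_features(feature, num_other_vehicles):
    # One 5-tuple of zeros per missing vehicle, with -1.0 flagging "no vehicle present".
    for _ in range(MAX_NUM_VEHICLES - num_other_vehicles):
        feature += (0.0, 0.0, 0.0, 0.0, -1.0)
    return feature

# e.g. len(pad_other_vehicle_features((), 0)) == 5 * MAX_NUM_VEHICLES == 20
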
@@ -1156,12 +1156,13 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
        Returns True if the environment has terminated
        """
-       model_checks_violated = (self._LTL_preconditions_enable and \
-           self.current_model_checking_result())
+       model_checks_violated = self._LTL_preconditions_enable and \
+           self.current_model_checking_result()
        reached_goal = self._terminate_in_goal and self.goal_achieved
        self._check_collisions()
        self._check_ego_theta_out_of_range()
        terminated = self.termination_condition

        return model_checks_violated or reached_goal or terminated

    @property
@@ -1181,7 +1182,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
        return (self.ego.x >= rd.hlanes.end_pos) and \
            not self.collision_happened and \
-           not self.ego.APs['over_speed_limit']
+           (self.ego.v <= 1.1*rd.speed_limit)

    def reset(self):
        """Gym compliant reset function.
@@ -1229,7 +1230,6 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
        self.window.dispatch_events()

        # Text information about ego vehicle's states
-       # Right now, we are only training one option (Stop)
        info = "Ego Attributes:" + get_APs(
            self, EGO_INDEX, 'in_stop_region',
            'has_entered_stop_region', 'has_stopped_in_stop_region',

@@ -11,7 +11,7 @@ def high_level_policy_training(nb_steps=25000,
                               load_weights=False,
                               training=True,
                               testing=True,
-                              nb_episodes_for_test=10,
+                              nb_episodes_for_test=20,
                               max_nb_steps=100,
                               visualize=False,
                               tensorboard=False,
@@ -63,8 +63,7 @@ def high_level_policy_training(nb_steps=25000,
        agent.save_model(save_path)

    if testing:
-       options.set_controller_policy(agent.predict)
-       agent.test_model(options, nb_episodes=nb_episodes_for_test)
+       high_level_policy_testing(nb_episodes_for_test=nb_episodes_for_test)

    return agent
@@ -228,7 +227,6 @@ if __name__ == "__main__":
            load_weights=args.load_weights,
            save_path=args.save_file,
            tensorboard=args.tensorboard,
-           nb_episodes_for_test=20,
            visualize=args.visualize)

    if args.test:

import json
+import os # for the use of os.path.isfile

from .simple_intersection.maneuvers import *
from .simple_intersection.mcts_maneuvers import *
from backends import RLController, DDPGLearner, MCTSLearner, OnlineMCTSController, ManualPolicy
@@ -155,19 +156,34 @@ class OptionsGraph:
    # TODO: error handling
    def load_trained_low_level_policies(self):
        for key, maneuver in self.maneuvers.items():
-           agent = DDPGLearner(
-               input_shape=(maneuver.get_reduced_feature_length(), ),
-               nb_actions=2,
-               gamma=0.99,
-               nb_steps_warmup_critic=200,
-               nb_steps_warmup_actor=200,
-               lr=1e-3)
-           agent.load_model("backends/trained_policies/" + key + "/" + key +
-                            "_weights.h5f")
-           maneuver.set_low_level_trained_policy(agent.predict)
-           maneuver._cost_weights = (20.0 * 1e-3, 1.0 * 1e-3, 0.25 * 1e-3,
-                                     1.0 * 1e-3, 100.0 * 1e-3, 0.1 * 1e-3,
-                                     0.25 * 1e-3, 0.1 * 1e-3)
+           trained_policy_path = "backends/trained_policies/" + key + "/"
+           critic_file_exists = os.path.isfile(trained_policy_path + key + "_weights_critic.h5f")
+           actor_file_exists = os.path.isfile(trained_policy_path + key + "_weights_actor.h5f")
+
+           if actor_file_exists and critic_file_exists:
+               agent = DDPGLearner(
+                   input_shape=(maneuver.get_reduced_feature_length(),),
+                   nb_actions=2,
+                   gamma=0.99,
+                   nb_steps_warmup_critic=200,
+                   nb_steps_warmup_actor=200,
+                   lr=1e-3)
+               agent.load_model(trained_policy_path + key + "_weights.h5f")
+               maneuver.set_low_level_trained_policy(agent.predict)
+           elif not critic_file_exists and actor_file_exists:
+               print("\n Warning: unable to load the low-level policy of \"" + key +
+                     "\": the critic weights file has to be located in the same " +
+                     "directory as the actor weights file; the manual policy will be used instead.\n")
+           else:
+               print("\n Warning: the trained low-level policy of \"" + key +
+                     "\" does not exist; the manual policy will be used.\n")
+
+           # Set the cost weights for high-level policy training.
+           # TODO: this shouldn't be initialized here, but within the ManeuverBase class
+           # (e.g., add a flag indicating high-level training and set the weights there)...
+           maneuver._cost_weights = (100.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
+                                     100.0 * 1e-3, 0.1 * 1e-3, 0.25 * 1e-3, 0.1 * 1e-3)

            if self.config["method"] == "mcts":
                maneuver.timeout = np.inf

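Based on the file checks above, the loader expects each maneuver's actor and critic weight files to sit together under a folder named after the maneuver. A small sketch of the paths being probed (the function name is made up; maneuver keys such as "stop" come from the config hunk earlier):

def expected_weight_files(key):
    # The two files load_trained_low_level_policies looks for before loading a DDPG policy.
    base = "backends/trained_policies/" + key + "/"
    return {"actor": base + key + "_weights_actor.h5f",
            "critic": base + key + "_weights_critic.h5f"}

# e.g. expected_weight_files("stop")["actor"]
# == "backends/trained_policies/stop/stop_weights_actor.h5f"

Note that agent.load_model() is handed only the base path "<key>_weights.h5f"; presumably the DDPGLearner wrapper (like keras-rl's DDPG agent) derives the _actor/_critic file names from it.
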
@@ -20,7 +20,7 @@ class ManeuverBase(EpisodicEnvBase):
    learning_mode = 'training'

    #: timeout (i.e., time horizon for termination)
-   # By default, the time-out horizon is 1 as in Paxton et. al (2017).
+   # By default, the time-out horizon is 1.
    timeout = 1

    #: the option specific weight vector for cost of driving, which is
@@ -153,8 +153,7 @@ class ManeuverBase(EpisodicEnvBase):
                # in this case, no additional reward by Default
                # (i.e., self._extra_r_terminal = None by default).
                self._terminal_reward_superposition(self._extra_r_terminal)
-               info[
-                   'maneuver_termination_reason'] = 'extra_termination_condition'
+               info['maneuver_termination_reason'] = 'extra_termination_condition'
        if self.timeout_happened:
            if self._give_reward_on_timeout:
                # in this case, no additional reward by Default
