Commit a90b4bc5 authored by Jae Young Lee's avatar Jae Young Lee

Add and train more low-level policies, and train a high-level policy.

The high-level policy was trained without the changelane maneuver but with the immediatestop maneuver. Two problems remain: 1) the agent chooses the changelane maneuver too frequently; 2) before the stop region, the immediatestop maneuver works but was not chosen properly after 2.5m in the high-level policy training...
parent 70ad9bf5
......@@ -51,7 +51,7 @@ class ControllerBase(PolicyBase):
Returns the state at the end of node execution, the total reward, episode_termination_flag, and info
'''
# TODO: this is never called when you test the high-level policy rather than train it...
def step_current_node(self, visualize_low_level_steps=False):
total_reward = 0
self.node_terminal_state_reached = False
......@@ -59,6 +59,7 @@ class ControllerBase(PolicyBase):
observation, reward, terminal, info = self.low_level_step_current_node()
if visualize_low_level_steps:
self.env.render()
# TODO: make the total_reward discounted....
total_reward += reward
total_reward += self.current_node.high_level_extra_reward
......
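The TODO above asks for a discounted total reward; a minimal standalone sketch of that accumulation, with gamma as an assumed discount factor rather than anything defined in ControllerBase:

# Hedged sketch of the discounting the TODO refers to; `gamma` is an assumed
# parameter, not an attribute of the current controller code.
def discounted_return(rewards, gamma=0.99):
    """Discount later low-level steps more strongly than earlier ones."""
    total, discount = 0.0, 1.0
    for r in rewards:
        total += discount * r
        discount *= gamma
    return total

# Example: three low-level step rewards accumulated with discounting.
print(discounted_return([1.0, 0.5, -0.2]))  # 1.0 + 0.99 * 0.5 + 0.99**2 * (-0.2)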
......@@ -269,9 +269,12 @@ class DQNLearner(LearnerBase):
"""
model = Sequential()
model.add(Flatten(input_shape=(1, ) + self.input_shape))
model.add(Dense(32))
#model.add(Dense(64))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(32))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(self.nb_actions))
model.add(Activation('linear'))
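For reference, the updated lines build a small fully connected Q-network; a standalone sketch of the same architecture, with example values standing in for self.input_shape and self.nb_actions (both are assumptions here):

from keras.models import Sequential
from keras.layers import Activation, Dense, Flatten

input_shape = (45,)  # assumed feature length, for illustration only
nb_actions = 5       # assumed number of high-level options

model = Sequential()
model.add(Flatten(input_shape=(1,) + input_shape))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))
model.summary()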
......@@ -435,6 +438,7 @@ class DQNAgentOverOptions(DQNAgent):
initiation conditions."""
invalid_node_indices = list()
for index, option_alias in enumerate(self.low_level_policy_aliases):
# TODO: Move reset_maneuver elsewhere, since this is a "get" function.
self.low_level_policies[option_alias].reset_maneuver()
if not self.low_level_policies[option_alias].initiation_condition:
invalid_node_indices.append(index)
......
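A hedged sketch (not the repository's code) of how the invalid node indices gathered above are typically used: mask the corresponding Q-values so that options whose initiation condition fails can never be selected:

import numpy as np

def masked_argmax(q_values, invalid_node_indices):
    """Select the best option while ignoring options that cannot be initiated."""
    q = np.asarray(q_values, dtype=float).copy()
    q[list(invalid_node_indices)] = -np.inf  # rule out invalid options
    return int(np.argmax(q))

# Example: option 1 cannot be initiated, so option 2 is chosen instead.
print(masked_argmax([0.3, 0.9, 0.7], invalid_node_indices=[1]))  # -> 2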
......@@ -30,6 +30,7 @@ class ManualPolicy(ControllerBase):
new_node = None
if self.low_level_policies[self.current_node].termination_condition:
for next_node in self.adj[self.current_node]:
self.low_level_policies[next_node].reset_maneuver()
if self.low_level_policies[next_node].initiation_condition:
new_node = next_node
break # change current_node to the highest priority next node
......
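The loop above returns the first applicable successor, so the ordering of self.adj encodes priority; a small self-contained illustration with hypothetical node names and conditions:

# Hypothetical adjacency list: successors appear in decreasing priority order.
adj = {"stop": ["immediatestop", "follow", "keeplane"]}

def pick_next_node(current, adj, can_initiate):
    """Return the highest-priority successor whose initiation condition holds."""
    for candidate in adj[current]:
        if can_initiate(candidate):
            return candidate
    return None  # keep the current node if no successor applies

# Example: "immediatestop" is not applicable here, so "follow" is selected.
print(pick_next_node("stop", adj, can_initiate=lambda n: n != "immediatestop"))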
......@@ -3,7 +3,7 @@
"wait": "Wait",
"follow": "Follow",
"stop": "Stop",
"changelane": "ChangeLane",
"immediatestop": "ImmediateStop",
"keeplane": "KeepLane"
},
......
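The mapping above associates option aliases with maneuver class names; a hedged sketch of how such a mapping could be resolved into maneuver objects (the "nodes" key and the class registry below are assumptions for illustration, not the repository's actual schema):

import json

config_text = """
{
  "nodes": {
    "wait": "Wait",
    "follow": "Follow",
    "stop": "Stop",
    "immediatestop": "ImmediateStop",
    "keeplane": "KeepLane"
  }
}
"""

# Stand-in classes; in the repository these would come from the maneuvers module.
class_registry = {name: type(name, (), {}) for name in
                  ("Wait", "Follow", "Stop", "ImmediateStop", "KeepLane")}

config = json.loads(config_text)
maneuvers = {alias: class_registry[name]() for alias, name in config["nodes"].items()}
print(sorted(maneuvers))  # ['follow', 'immediatestop', 'keeplane', 'stop', 'wait']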
......@@ -174,6 +174,6 @@ class Features(object):
# Add buffer features to make a fixed length feature vector
for i in range(MAX_NUM_VEHICLES - len(self.other_vehs)):
feature += (0.0, 0.0, 0.0, 0.0, -1)
feature += (0.0, 0.0, 0.0, 0.0, -1.0)
return feature
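The change above makes the dummy entries floating-point throughout; a standalone sketch of the fixed-length padding, with MAX_NUM_VEHICLES and the five-element per-vehicle layout treated as assumed example values:

MAX_NUM_VEHICLES = 4  # assumed value, for illustration only

def pad_other_vehicle_features(feature, num_other_vehs):
    """Append dummy vehicle slots so the feature vector has a fixed length.

    Each dummy slot is (0.0, 0.0, 0.0, 0.0, -1.0); the trailing -1.0 marks the
    slot as unused, and writing it as a float keeps the vector dtype uniform.
    """
    for _ in range(MAX_NUM_VEHICLES - num_other_vehs):
        feature += (0.0, 0.0, 0.0, 0.0, -1.0)
    return feature

# Example: two real vehicles already encoded (10 values), two dummy slots added.
print(len(pad_other_vehicle_features(tuple(range(10)), num_other_vehs=2)))  # 20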
......@@ -1156,12 +1156,13 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
Returns True if the environment has terminated
"""
model_checks_violated = (self._LTL_preconditions_enable and \
self.current_model_checking_result())
model_checks_violated = self._LTL_preconditions_enable and \
self.current_model_checking_result()
reached_goal = self._terminate_in_goal and self.goal_achieved
self._check_collisions()
self._check_ego_theta_out_of_range()
terminated = self.termination_condition
return model_checks_violated or reached_goal or terminated
@property
......@@ -1181,7 +1182,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
return (self.ego.x >= rd.hlanes.end_pos) and \
not self.collision_happened and \
not self.ego.APs['over_speed_limit']
(self.ego.v <= 1.1*rd.speed_limit)
def reset(self):
"""Gym compliant reset function.
......@@ -1229,7 +1230,6 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
self.window.dispatch_events()
# Text information about ego vehicle's states
# Right now, we are only training one option (Stop)
info = "Ego Attributes:" + get_APs(
self, EGO_INDEX, 'in_stop_region',
'has_entered_stop_region', 'has_stopped_in_stop_region',
......
......@@ -11,7 +11,7 @@ def high_level_policy_training(nb_steps=25000,
load_weights=False,
training=True,
testing=True,
nb_episodes_for_test=10,
nb_episodes_for_test=20,
max_nb_steps=100,
visualize=False,
tensorboard=False,
......@@ -63,8 +63,7 @@ def high_level_policy_training(nb_steps=25000,
agent.save_model(save_path)
if testing:
options.set_controller_policy(agent.predict)
agent.test_model(options, nb_episodes=nb_episodes_for_test)
high_level_policy_testing(nb_episodes_for_test=nb_episodes_for_test)
return agent
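A hedged usage sketch of the training entry point with the keyword arguments visible in this diff; the module path in the import is a guess, not something this commit defines:

# Hypothetical module path; adjust to wherever high_level_policy_training lives.
from high_level_policy_main import high_level_policy_training

agent = high_level_policy_training(
    nb_steps=25000,
    load_weights=False,
    training=True,
    testing=True,             # triggers high_level_policy_testing afterwards
    nb_episodes_for_test=20,  # new default introduced in this commit
    max_nb_steps=100,
    visualize=False,
    tensorboard=False)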
......@@ -228,7 +227,6 @@ if __name__ == "__main__":
load_weights=args.load_weights,
save_path=args.save_file,
tensorboard=args.tensorboard,
nb_episodes_for_test=20,
visualize=args.visualize)
if args.test:
......
import json
import os  # for os.path.isfile
from .simple_intersection.maneuvers import *
from .simple_intersection.mcts_maneuvers import *
from backends import RLController, DDPGLearner, MCTSLearner, OnlineMCTSController, ManualPolicy
......@@ -155,19 +156,34 @@ class OptionsGraph:
# TODO: error handling
def load_trained_low_level_policies(self):
for key, maneuver in self.maneuvers.items():
agent = DDPGLearner(
input_shape=(maneuver.get_reduced_feature_length(), ),
nb_actions=2,
gamma=0.99,
nb_steps_warmup_critic=200,
nb_steps_warmup_actor=200,
lr=1e-3)
agent.load_model("backends/trained_policies/" + key + "/" + key +
"_weights.h5f")
maneuver.set_low_level_trained_policy(agent.predict)
maneuver._cost_weights = (20.0 * 1e-3, 1.0 * 1e-3, 0.25 * 1e-3,
1.0 * 1e-3, 100.0 * 1e-3, 0.1 * 1e-3,
0.25 * 1e-3, 0.1 * 1e-3)
trained_policy_path = "backends/trained_policies/" + key + "/"
critic_file_exists = os.path.isfile(trained_policy_path + key + "_weights_critic.h5f")
actor_file_exists = os.path.isfile(trained_policy_path + key + "_weights_actor.h5f")
if actor_file_exists and critic_file_exists:
agent = DDPGLearner(
input_shape=(maneuver.get_reduced_feature_length(),),
nb_actions=2,
gamma=0.99,
nb_steps_warmup_critic=200,
nb_steps_warmup_actor=200,
lr=1e-3)
agent.load_model(trained_policy_path + key + "_weights.h5f")
maneuver.set_low_level_trained_policy(agent.predict)
elif not critic_file_exists and actor_file_exists:
print("\n Warning: unable to load the low-level policy of \"" + key +
"\". the file of critic weights have to be located in the same " +
"directory of the actor weights file; the manual policy will be used instead.\n")
else:
print("\n Warning: the trained low-level policy of \"" + key +
"\" does not exists; the manual policy will be used.\n")
# Set the cost weights for high-level policy training.
# TODO: this shouldn't be initialized here but within the ManeuverBase class (e.g., add a flag indicating high-level training and set the weights there)...
maneuver._cost_weights = (100.0 * 1e-3, 10.0 * 1e-3, 0.25 * 1e-3, 1.0 * 1e-3,
100.0 * 1e-3, 0.1 * 1e-3, 0.25 * 1e-3, 0.1 * 1e-3)
if self.config["method"] == "mcts":
maneuver.timeout = np.inf
......
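A minimal sketch of the direction the TODO above suggests: keep both weight sets inside the maneuver class and select them with a flag, instead of overwriting the private _cost_weights attribute from OptionsGraph (every name below except _cost_weights is hypothetical; the two tuples are the low-level and high-level weights appearing in this diff):

class ManeuverCostWeights:
    """Illustrative stand-in for the flag-based weight selection the TODO proposes."""

    LOW_LEVEL_WEIGHTS = (20.0e-3, 1.0e-3, 0.25e-3, 1.0e-3,
                         100.0e-3, 0.1e-3, 0.25e-3, 0.1e-3)
    HIGH_LEVEL_WEIGHTS = (100.0e-3, 10.0e-3, 0.25e-3, 1.0e-3,
                          100.0e-3, 0.1e-3, 0.25e-3, 0.1e-3)

    def __init__(self, high_level_training=False):
        # The flag chooses the weight set; no external code touches _cost_weights.
        self._cost_weights = (self.HIGH_LEVEL_WEIGHTS if high_level_training
                              else self.LOW_LEVEL_WEIGHTS)

print(ManeuverCostWeights(high_level_training=True)._cost_weights[:2])  # (0.1, 0.01)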
......@@ -20,7 +20,7 @@ class ManeuverBase(EpisodicEnvBase):
learning_mode = 'training'
#: timeout (i.e., time horizon for termination)
# By default, the time-out horizon is 1 as in Paxton et al. (2017).
# By default, the time-out horizon is 1.
timeout = 1
#: the option specific weight vector for cost of driving, which is
......@@ -153,8 +153,7 @@ class ManeuverBase(EpisodicEnvBase):
# in this case, no additional reward by Default
# (i.e., self._extra_r_terminal = None by default).
self._terminal_reward_superposition(self._extra_r_terminal)
info[
'maneuver_termination_reason'] = 'extra_termination_condition'
info['maneuver_termination_reason'] = 'extra_termination_condition'
if self.timeout_happened:
if self._give_reward_on_timeout:
# in this case, no additional reward by Default
......