Commit 7a31ba33 authored by Jae Young Lee

Merge branch 'master' into Train_0.1m_steps_and_improve_Wait

parents f25adee3 54079509
......@@ -2,4 +2,4 @@ from .manual_policy import ManualPolicy
from .mcts_learner import MCTSLearner
from .rl_controller import RLController
from .kerasrl_learner import DDPGLearner, DQNLearner
from .online_mcts_controller import OnlineMCTSController
\ No newline at end of file
from .mcts_controller import MCTSController
\ No newline at end of file
......@@ -13,7 +13,7 @@ from rl.policy import GreedyQPolicy, EpsGreedyQPolicy, MaxBoltzmannQPolicy
from rl.callbacks import ModelIntervalCheckpoint
import numpy as np
import copy
class DDPGLearner(LearnerBase):
def __init__(self,
......@@ -275,9 +275,6 @@ class DQNLearner(LearnerBase):
model.add(Flatten(input_shape=(1, ) + self.input_shape))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='tanh'))
model.add(Dense(self.nb_actions))
print(model.summary())
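
The revised high-level network above is a plain feed-forward model: two relu hidden layers, one tanh hidden layer, and a linear output with one unit per option. Below is a hedged sketch of how such a model is typically wrapped into a keras-rl DQN agent; the memory size, policy, and optimizer settings are assumptions for illustration, not the values actually used by DQNLearner.

from keras.models import Sequential
from keras.layers import Dense, Flatten
from keras.optimizers import Adam
from rl.agents.dqn import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import EpsGreedyQPolicy

input_shape = (50, )   # observation size used elsewhere in this commit
nb_actions = 5         # one discrete action per option/maneuver alias

# Same layer layout as the diff above.
model = Sequential()
model.add(Flatten(input_shape=(1, ) + input_shape))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='tanh'))
model.add(Dense(nb_actions))

# Assumed wrapper settings; DQNLearner's real memory/policy choices are not shown in this hunk.
agent = DQNAgent(model=model, nb_actions=nb_actions,
                 memory=SequentialMemory(limit=100000, window_length=1),
                 policy=EpsGreedyQPolicy())
agent.compile(Adam(lr=1e-3), metrics=['mae'])
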
......@@ -386,9 +383,10 @@ class DQNLearner(LearnerBase):
termination_reason_counter[termination_reason] += 1
else:
termination_reason_counter[termination_reason] = 1
env.reset()
if episode_reward >= success_reward_threshold:
# TODO: remove the env-specific code below
if env.env.goal_achieved:
success_count += 1
env.reset()
print("Episode {}: steps:{}, reward:{}".format(
n + 1, step, episode_reward))
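
The ordering change above is the point of this hunk: env.reset() now runs only after env.env.goal_achieved has been read, because resetting first would wipe the terminal-state flag the success counter depends on. A self-contained sketch of the pattern (the dummy environment below is illustrative, not the project's env):

class _DummyEnv(object):
    """Stand-in with only the fields needed to show the ordering."""
    def __init__(self):
        self.goal_achieved = False

    def reset(self):
        self.goal_achieved = False

env = _DummyEnv()
env.goal_achieved = True      # pretend the episode just ended at the goal
success_count = 0
if env.goal_achieved:         # inspect the terminal state first ...
    success_count += 1
env.reset()                   # ... and only then reset for the next episode
assert success_count == 1
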
......@@ -416,9 +414,14 @@ class DQNLearner(LearnerBase):
action_num = self.agent_model.low_level_policy_aliases.index(
option_alias)
q_values = self.agent_model.get_modified_q_values(observation)
max_q_value = np.abs(np.max(q_values))
q_values = [np.exp(q_value / max_q_value) for q_value in q_values]
# print('softq q_values are %s' % dict(zip(self.agent_model.low_level_policy_aliases, q_values)))
# oq_values = copy.copy(q_values)
if q_values[action_num] == -np.inf:
return 0
max_q_value = np.max(q_values)
q_values = [np.exp(q_value - max_q_value) for q_value in q_values]
relevant = q_values[action_num] / np.sum(q_values)
# print('softq: %s -> %s' % (oq_values, relevant))
return relevant
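
The rewritten soft-Q computation above subtracts the maximum Q-value before exponentiating (the standard numerically stable softmax) instead of dividing by its magnitude, and returns 0 immediately when the requested option has been masked to -inf. A self-contained sketch of the same computation, assuming only numpy:

import numpy as np

def soft_q_probability(q_values, action_num):
    """Softmax probability of one action, computed in a numerically stable way.

    Masked actions (q == -inf) get probability 0 and contribute exp(-inf) = 0
    to the normalizer, so they never distort the remaining probabilities.
    """
    q_values = np.asarray(q_values, dtype=float)
    if q_values[action_num] == -np.inf:
        return 0.0
    max_q_value = np.max(q_values)
    exp_q = np.exp(q_values - max_q_value)
    return exp_q[action_num] / np.sum(exp_q)

# Example: the -inf entry is ignored, the rest sum to 1.
print(soft_q_probability([2.0, 1.0, -np.inf], 0))  # ~0.731
print(soft_q_probability([2.0, 1.0, -np.inf], 2))  # 0.0
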
......@@ -543,6 +546,7 @@ class DQNAgentOverOptions(DQNAgent):
self.recent_observation = observation
self.recent_action = action
# print('forward gives %s from %s' % (action, dict(zip(self.low_level_policy_aliases, q_values))))
return action
def get_modified_q_values(self, observation):
......
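
The body of get_modified_q_values is not shown in this diff; judging from the -np.inf check in get_softq_value_using_option_alias above, it presumably returns the network's Q-values with options that cannot be initiated from the current state masked to -inf. The sketch below illustrates that assumption only; the function and argument names are placeholders, not the repository's implementation.

import numpy as np

def modified_q_values(raw_q_values, applicable_mask):
    """Assumed behaviour: copy the raw Q-values and mask options that are not
    currently applicable with -inf so they are never selected."""
    q_values = np.array(raw_q_values, dtype=float)
    q_values[~np.asarray(applicable_mask, dtype=bool)] = -np.inf
    return q_values

# Example: option 1 is not applicable, so its Q-value is masked out.
print(modified_q_values([0.3, 1.2, -0.5], [True, False, True]))
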
......@@ -3,9 +3,8 @@ from .mcts_learner import MCTSLearner
import tqdm
import numpy as np
class OnlineMCTSController(ControllerBase):
"""Online MCTS."""
class MCTSController(ControllerBase):
"""MCTS Controller."""
def __init__(self, env, low_level_policies, start_node_alias):
"""Constructor for manual policy execution.
......@@ -14,13 +13,15 @@ class OnlineMCTSController(ControllerBase):
env: env instance
low_level_policies: low level policies dictionary
"""
super(OnlineMCTSController, self).__init__(env, low_level_policies,
super(MCTSController, self).__init__(env, low_level_policies,
start_node_alias)
self.curr_node_alias = start_node_alias
self.controller_args_defaults = {
"predictor": None,
"max_depth": 5, # MCTS depth
"nb_traversals": 30, # MCTS traversals before decision
"max_depth": 10, # MCTS depth
"nb_traversals": 100, # MCTS traversals before decision
"debug": False,
"rollout_timeout": 500
}
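
These defaults are later overridden through set_controller_args (see the evaluation script further down, which passes predictor, max_depth, nb_traversals and debug). ControllerBase itself is not part of this diff, so the following is only a plausible sketch of the defaults-override pattern, not its actual implementation:

class ControllerSketch(object):
    def __init__(self):
        self.controller_args_defaults = {
            "predictor": None,
            "max_depth": 10,       # MCTS depth
            "nb_traversals": 100,  # MCTS traversals before decision
            "debug": False,
            "rollout_timeout": 500
        }

    def set_controller_args(self, **kwargs):
        # Every default becomes an attribute; keyword arguments win over defaults.
        for key, default in self.controller_args_defaults.items():
            setattr(self, key, kwargs.get(key, default))

controller = ControllerSketch()
controller.set_controller_args(predictor=lambda obs, alias: 0.5, max_depth=5)
print(controller.max_depth, controller.nb_traversals)  # 5 100
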
def set_current_node(self, node_alias):
......@@ -30,11 +31,21 @@ class OnlineMCTSController(ControllerBase):
self.env.set_ego_info_text(node_alias)
def change_low_level_references(self, env_copy):
# Create a copy of the environment and change references in low level policies.
"""Change references in low level policies by updating the environment
with the copy of the environment.
Args:
env_copy: reference to copy of the environment
"""
self.env = env_copy
for policy in self.low_level_policies.values():
policy.env = env_copy
def check_env(self, x):
"""Prints the object id of the environment. Debugging function."""
print('%s: self.env is %s' % (x, str(id(self.env))))
def can_transition(self):
return not self.env.is_terminal()
......@@ -45,29 +56,54 @@ class OnlineMCTSController(ControllerBase):
"predictor is not set. Use set_controller_args().")
# Store the env at this point
orig_env = self.env
# self.check_env('i')
np.random.seed()
# Change low level references before init MCTSLearner instance
env_before_mcts = orig_env.copy()
self.change_low_level_references(env_before_mcts)
print('Current Node: %s' % self.curr_node_alias)
mcts = MCTSLearner(self.env, self.low_level_policies,
self.curr_node_alias)
mcts.max_depth = self.max_depth
mcts.set_controller_args(predictor=self.predictor)
# self.check_env('b4')
# print('Current Node: %s' % self.curr_node_alias)
if not hasattr(self, 'mcts'):
if self.debug:
print('Creating MCTS Tree: max depth {}'.format(self.max_depth))
self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth,
debug=self.debug, rollout_timeout=self.rollout_timeout)
self.mcts.set_controller_args(predictor=self.predictor)
if self.debug:
print('')
# Do nb_traversals number of traversals, reset env to this point every time
# print('Doing MCTS with params: max_depth = %d, nb_traversals = %d' % (self.max_depth, self.nb_traversals))
for num_epoch in range(
self.nb_traversals): # tqdm.tqdm(range(self.nb_traversals)):
mcts.curr_node_num = 0
num_epoch = 0
if not self.debug:
progress = tqdm.tqdm(total=self.nb_traversals-self.mcts.tree.root.N)
while num_epoch < self.nb_traversals: # tqdm
if self.mcts.tree.root.N >= self.nb_traversals:
break
env_begin_epoch = env_before_mcts.copy()
self.change_low_level_references(env_begin_epoch)
# self.check_env('e%d' % num_epoch)
init_obs = self.env.get_features_tuple()
v, all_ep_R = mcts.traverse(init_obs)
self.mcts.env = env_begin_epoch
if self.debug:
print('Search %d: ' % num_epoch, end=' ')
success = self.mcts.search(init_obs)
num_epoch += 1
if not self.debug:
progress.update(1)
if not self.debug:
progress.close()
self.change_low_level_references(orig_env)
# self.check_env('p')
# Find the nodes from the root node
mcts.curr_node_num = 0
print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
node_after_transition = mcts.get_best_node(
self.env.get_features_tuple(), use_ucb=False)
print('MCTS suggested next option: %s' % node_after_transition)
self.set_current_node(node_after_transition)
# print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
node_after_transition = self.mcts.best_action(self.mcts.tree.root, 0)
if self.debug:
print('MCTS suggested next option: %s' % node_after_transition)
p = {'overall': self.mcts.tree.root.Q * 1.0 / self.mcts.tree.root.N}
for edge in self.mcts.tree.root.edges.keys():
next_node = self.mcts.tree.nodes[self.mcts.tree.root.edges[edge]]
p[edge] = next_node.Q * 1.0 / next_node.N
if self.debug:
print('Btw: %s' % str(p))
self.mcts.tree.reconstruct(node_after_transition)
self.set_current_node(node_after_transition)
\ No newline at end of file
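
do_transition now keeps one persistent MCTSLearner: it searches until the root has accumulated nb_traversals visits, selects the next option with best_action(root, 0), i.e. with the exploration constant set to 0, optionally prints each edge's mean return Q/N in debug mode, and reconstructs the tree around the chosen child so statistics carry over to the next decision. The sketch below shows exploitation-only selection over Q/N with a simplified node layout (edges map directly to child nodes here, whereas the real tree stores node ids); the actual best_action may differ.

class NodeSketch(object):
    def __init__(self, Q=0.0, N=0, edges=None):
        self.Q = Q                 # accumulated return
        self.N = N                 # visit count
        self.edges = edges or {}   # option alias -> child node

def best_action(node, exploration_constant=0.0):
    # With exploration_constant == 0 this is pure exploitation over Q/N.
    scores = {alias: child.Q / max(child.N, 1)
              for alias, child in node.edges.items()}
    return max(scores, key=scores.get)

root = NodeSketch(Q=40.0, N=100, edges={
    'wait': NodeSketch(Q=10.0, N=40),
    'changelane': NodeSketch(Q=25.0, N=50),
})
print(best_action(root, 0))  # changelane (mean 0.5 > 0.25)
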
......@@ -4,7 +4,7 @@ from backends import DDPGLearner, DQNLearner, MCTSLearner
import numpy as np
import tqdm
import argparse
import time, datetime
import sys
......@@ -28,310 +28,132 @@ class Logger(object):
sys.stdout = Logger()
# TODO: make a separate file for this function.
def mcts_training(nb_traversals,
save_every=20,
visualize=False,
load_saved=False,
save_file="mcts.pickle"):
"""Do RL of the low-level policy of the given maneuver and test it.
Args:
nb_traversals: number of MCTS traversals
save_every: save at every these many traversals
visualize: visualization / rendering
"""
# initialize the numpy random number generator
np.random.seed()
# load options graph
options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv)
options.load_trained_low_level_policies()
agent = DQNLearner(
input_shape=(50, ),
nb_actions=options.get_number_of_nodes(),
low_level_policies=options.maneuvers)
agent.load_model(
"backends/trained_policies/highlevel/highlevel_weights.h5f")
options.set_controller_args(
predictor=agent.get_softq_value_using_option_alias)
options.controller.max_depth = 20
if load_saved:
options.controller.load_model(save_file)
total_epochs = nb_traversals // save_every
trav_num = 1
print('Total number of epochs = %d' % total_epochs)
for num_epoch in range(total_epochs):
last_rewards = []
beg_trav_num = trav_num
for num_traversal in tqdm.tqdm(range(save_every)):
options.controller.curr_node_num = 0
init_obs = options.reset()
v, all_ep_R = options.controller.traverse(
init_obs, visualize=visualize)
last_rewards += [all_ep_R]
trav_num += 1
options.controller.save_model(save_file)
success = lambda x: x > 50
success_rate = np.sum(list(map(
success, last_rewards))) / (len(last_rewards) * 1.0)
print('success rate: %f' % success_rate)
print('Average Reward (%d-%d): %f\n' % (beg_trav_num, trav_num - 1,
np.mean(last_rewards)))
def mcts_evaluation(nb_traversals,
num_trials=5,
def mcts_evaluation(depth,
nb_traversals,
nb_episodes,
nb_trials,
visualize=False,
save_file="mcts.pickle",
pretrained=False):
debug=False):
"""Do RL of the low-level policy of the given maneuver and test it.
Args:
nb_traversals: number of MCTS traversals
save_every: save at every these many traversals
depth: depth of each tree search
nb_traversals: number of MCTS traversals per episode
nb_episodes: number of episodes per trial
nb_trials: number of trials
visualize: visualization / rendering
debug: whether or not to show debug information
"""
# initialize the numpy random number generator
np.random.seed()
# load options graph
options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv)
# load config and maneuvers
options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv,
randomize_special_scenarios=True)
options.load_trained_low_level_policies()
# load high level policy for UCT prediction
agent = DQNLearner(
input_shape=(50, ),
nb_actions=options.get_number_of_nodes(),
low_level_policies=options.maneuvers)
agent.load_model(
"backends/trained_policies/highlevel/highlevel_weights.h5f")
options.set_controller_args(
predictor=agent.get_softq_value_using_option_alias)
options.controller.max_depth = 20
if pretrained:
save_file = "backends/trained_policies/mcts/" + save_file
success_list = []
print('Total number of trials = %d' % num_trials)
for trial in range(num_trials):
num_successes = 0
options.controller.load_model(save_file)
for num_traversal in tqdm.tqdm(range(nb_traversals)):
options.controller.curr_node_num = 0
init_obs = options.reset()
v, all_ep_R = options.controller.traverse(
init_obs, visualize=visualize)
if all_ep_R > 50:
num_successes += 1
print("\nTrial {}: success: {}".format(trial + 1, num_successes))
success_list.append(num_successes)
print("\nSuccess: Avg: {}, Std: {}".format(
np.mean(success_list), np.std(success_list)))
def online_mcts(nb_episodes=10):
# MCTS visualization is off
# initialize the numpy random number generator
np.random.seed()
# load options graph
options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv)
options.load_trained_low_level_policies()
agent = DQNLearner(
input_shape=(50, ),
nb_actions=options.get_number_of_nodes(),
low_level_policies=options.maneuvers)
agent.load_model(
"backends/trained_policies/highlevel/highlevel_weights_772.h5f")
# set predictor
options.set_controller_args(
predictor=agent.get_softq_value_using_option_alias)
# Loop
num_successes = 0
for num_ep in range(nb_episodes):
init_obs = options.reset()
episode_reward = 0
first_time = True
while not options.env.is_terminal():
if first_time:
first_time = False
else:
print('Stepping through ...')
features, R, terminal, info = options.controller.\
step_current_node(visualize_low_level_steps=True)
episode_reward += R
print('Intermediate Reward: %f (ego x = %f)' %
(R, options.env.vehs[0].x))
print('')
if options.controller.can_transition():
options.controller.do_transition()
print('')
print('EPISODE %d: Reward = %f' % (num_ep, episode_reward))
print('')
print('')
if episode_reward > 50: num_successes += 1
print("Policy succeeded {} times!".format(num_successes))
def evaluate_online_mcts(nb_episodes=20, nb_trials=5):
# MCTS visualization is off
# initialize the numpy random number generator
np.random.seed()
# load options graph
options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv)
options.load_trained_low_level_policies()
agent = DQNLearner(
input_shape=(50, ),
nb_actions=options.get_number_of_nodes(),
low_level_policies=options.maneuvers)
agent.load_model(
"backends/trained_policies/highlevel/highlevel_weights_772.h5f")
options.set_controller_args(
predictor=agent.get_softq_value_using_option_alias)
predictor=agent.get_softq_value_using_option_alias,
max_depth=depth,
nb_traversals=nb_traversals,
debug=debug)
# Evaluate
success_list = []
print("\nConducting {} trials of {} episodes each".format(
nb_trials, nb_episodes))
success_list = []
termination_reason_list = {}
for trial in range(nb_trials):
# Loop
overall_reward_list = []
overall_success_accuracy = []
for num_tr in range(nb_trials):
num_successes = 0
termination_reason_counter = {}
reward_list = []
for num_ep in range(nb_episodes):
init_obs = options.reset()
episode_reward = 0
first_time = True
start_time = time.time()
while not options.env.is_terminal():
if first_time:
first_time = False
else:
print('Stepping through ...')
# print('Stepping through ...')
features, R, terminal, info = options.controller.\
step_current_node(visualize_low_level_steps=True)
step_current_node(visualize_low_level_steps=visualize)
episode_reward += R
print('Intermediate Reward: %f (ego x = %f)' %
(R, options.env.vehs[0].x))
print('')
if terminal:
if 'episode_termination_reason' in info:
termination_reason = info[
'episode_termination_reason']
if termination_reason in termination_reason_counter:
termination_reason_counter[
termination_reason] += 1
else:
termination_reason_counter[
termination_reason] = 1
# print('Intermediate Reward: %f (ego x = %f)' %
# (R, options.env.vehs[0].x))
# print('')
if options.controller.can_transition():
options.controller.do_transition()
print('')
print('EPISODE %d: Reward = %f' % (num_ep, episode_reward))
print('')
print('')
if episode_reward > 50: num_successes += 1
print("\nTrial {}: success: {}".format(trial + 1, num_successes))
success_list.append(num_successes)
for reason, count in termination_reason_counter.items():
if reason in termination_reason_list:
termination_reason_list[reason].append(count)
else:
termination_reason_list[reason] = [count]
success_list = np.array(success_list)
print("\nSuccess: Avg: {}, Std: {}".format(
np.mean(success_list), np.std(success_list)))
print("Termination reason(s):")
for reason, count_list in termination_reason_list.items():
count_list = np.array(count_list)
print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list),
np.std(count_list)))
end_time = time.time()
total_time = int(end_time-start_time)
if options.env.goal_achieved:
num_successes += 1
print('Episode {}: Reward = {} ({})'.format(num_ep, episode_reward,
datetime.timedelta(seconds=total_time)))
reward_list += [episode_reward]
print("Trial {}: Reward = (Avg: {}, Std: {}), Successes: {}/{}".\
format(num_tr, np.mean(reward_list), np.std(reward_list), \
num_successes, nb_episodes))
overall_reward_list += reward_list
overall_success_accuracy += [num_successes * 1.0 / nb_episodes]
print('Overall: Reward = (Avg: {}, Std: {}), Success = (Avg: {}, Std: {})'.\
format(np.mean(overall_reward_list), np.std(overall_reward_list),
np.mean(overall_success_accuracy), np.std(overall_success_accuracy)))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--train",
help=
"Train an offline mcts with default settings. Always saved in root folder.",
action="store_true")
parser.add_argument(
"--evaluate",
help="Evaluate over n trials. "
"Uses backends/trained_policies/mcts/mcts.pickle by default",
action="store_true")
parser.add_argument(
"--saved_policy_in_root",
help=
"Use saved policies in root of project rather than backends/trained_policies/mcts/",
action="store_true")
parser.add_argument(
"--load_saved",
help="Load a saved policy from root folder first before training",
help="Evaluate over n trials, no visualization by default.",
action="store_true")
parser.add_argument(
"--visualize",
help=
"Visualize the training. Testing is always visualized. Evaluation is not visualized by default",
"Visualize the training.",
action="store_true")
parser.add_argument(
"--nb_traversals",
help="Number of traversals to perform. Default is 1000",
default=1000,
"--depth",
help="Max depth of tree per episode. Default is 10",
default=10,
type=int)
parser.add_argument(
"--save_every",
help=
"Saves every n traversals. Saves in root by default. Default is 500",
default=500,
"--nb_traversals",
help="Number of traversals to perform per episode. Default is 100",
default=100,
type=int)
parser.add_argument(
"--nb_traversals_for_test",
help="Number of episodes to evaluate. Default is 100",
default=100,
"--nb_episodes",
help="Number of episodes per trial to evaluate. Default is 10",
default=10,
type=int)
parser.add_argument(
"--nb_trials",
help="Number of trials to evaluate. Default is 10",
default=10,
help="Number of trials to evaluate. Default is 1",
default=1,
type=int)
parser.add_argument(
"--save_file",
help=
"filename to save/load the trained policy. Location is as specified by --saved_policy_in_root. Default name is mcts.pickle",
default="mcts.pickle")
"--debug",
help="Show debug output. Default is false",
action="store_true")
args = parser.parse_args()
if args.train:
mcts_training(
nb_traversals=args.nb_traversals,
save_every=args.save_every,
visualize=args.visualize,
load_saved=args.load_saved,
save_file=args.save_file)
if args.evaluate:
mcts_evaluation(
nb_traversals=args.nb_traversals_for_test,
num_trials=args.nb_trials,
depth=args.depth,
nb_traversals=args.nb_traversals,
nb_episodes=args.nb_episodes,
nb_trials=args.nb_trials,
visualize=args.visualize,
pretrained=not args.saved_policy_in_root,
save_file=args.save_file)
debug=args.debug)
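
The same evaluation can also be driven programmatically with the defaults the parser exposes; the module name in the import below is a placeholder, since this script's filename is not part of the diff.

from mcts import mcts_evaluation    # placeholder module name

mcts_evaluation(depth=10,           # --depth
                nb_traversals=100,  # --nb_traversals
                nb_episodes=10,     # --nb_episodes
                nb_trials=1,        # --nb_trials
                visualize=False,
                debug=False)
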
{
"nodes": {
"wait": "MCTSWait",
"follow": "MCTSFollow",
"stop": "MCTSStop",
"changelane": "MCTSChangeLane",
"keeplane": "MCTSKeepLane"
"wait": "Wait",
"follow": "Follow",
"stop": "Stop",
"changelane": "ChangeLane",
"keeplane": "KeepLane"
},
"edges": {
......
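
The node table above now maps each option alias to the plain maneuver classes (Wait, Follow, Stop, ChangeLane, KeepLane) instead of the MCTS-specific variants. A hedged sketch of consuming this config with the json module (the real OptionsGraph loading code is not shown in this hunk):

import json

with open("mcts_config.json") as f:
    config = json.load(f)

# Report which maneuver class backs each option alias.
for alias, maneuver_class_name in config["nodes"].items():
    print("option '%s' -> maneuver class '%s'" % (alias, maneuver_class_name))
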
import json
import os # for the use of os.path.isfile
from .simple_intersection.maneuvers import *
from .simple_intersection.mcts_maneuvers import *
from backends import RLController, DDPGLearner, MCTSLearner, OnlineMCTSController, ManualPolicy
from backends import RLController, DDPGLearner, MCTSController, ManualPolicy
class OptionsGraph:
"""Represent the options graph as a graph like structure. The configuration
......@@ -67,11 +65,8 @@ class OptionsGraph:
self.controller = ManualPolicy(self.env, self.maneuvers, self.adj,
self.start_node_alias)
elif self.config["method"] == "mcts":
self.controller = MCTSLearner(self.env, self.maneuvers,
self.start_node_alias)
elif self.config["method"] == "online_mcts":
self.controller = OnlineMCTSController(self.env, self.maneuvers,
self.start_node_alias)
self.controller = MCTSController(self.env, self.maneuvers,
self.start_node_alias)
else:
raise Exception(self.__class__.__name__ + \
"Controller to be used not specified")
......@@ -156,6 +151,7 @@ class OptionsGraph:
# TODO: error handling
def load_trained_low_level_policies(self):
for key, maneuver in self.maneuvers.items():
# TODO: Ensure that for manual policies, nothing is loaded
trained_policy_path = "backends/trained_policies/" + key + "/"
critic_file_exists = os.path.isfile(trained_policy_path + key + "_weights_critic.h5f")
actor_file_exists = os.path.isfile(trained_policy_path + key + "_weights_actor.h5f")
......@@ -180,9 +176,6 @@ class OptionsGraph:
print("\n Warning: the trained low-level policy of \"" + key +
"\" does not exists; the manual policy will be used.\n")
if self.config["method"] == "mcts":
maneuver.timeout = np.inf
def get_number_of_nodes(self):
return len(self.maneuvers)
......
from .maneuver_base import ManeuverBase
from env.simple_intersection.constants import *
import env.simple_intersection.road_geokinemetry as rd
from env.simple_intersection.features import extract_ego_features, extract_other_veh_features
from verifier.simple_intersection import LTLProperty
import numpy as np