Commit 788a93ad authored by Jae Young Lee

Merge branch 'master' into improve_Wait_and_highlevel

parents da44ef09 230741b9
@@ -2,4 +2,4 @@ from .manual_policy import ManualPolicy
from .mcts_learner import MCTSLearner
from .rl_controller import RLController
from .kerasrl_learner import DDPGLearner, DQNLearner
from .online_mcts_controller import OnlineMCTSController
\ No newline at end of file
from .mcts_controller import MCTSController
\ No newline at end of file
@@ -13,7 +13,7 @@ from rl.policy import GreedyQPolicy, EpsGreedyQPolicy, MaxBoltzmannQPolicy
from rl.callbacks import ModelIntervalCheckpoint
import numpy as np
import copy
class DDPGLearner(LearnerBase):
def __init__(self,
@@ -413,9 +413,14 @@ class DQNLearner(LearnerBase):
action_num = self.agent_model.low_level_policy_aliases.index(
option_alias)
q_values = self.agent_model.get_modified_q_values(observation)
max_q_value = np.abs(np.max(q_values))
q_values = [np.exp(q_value / max_q_value) for q_value in q_values]
# print('softq q_values are %s' % dict(zip(self.agent_model.low_level_policy_aliases, q_values)))
# oq_values = copy.copy(q_values)
if q_values[action_num] == -np.inf:
return 0
max_q_value = np.max(q_values)
q_values = [np.exp(q_value - max_q_value) for q_value in q_values]
relevant = q_values[action_num] / np.sum(q_values)
# print('softq: %s -> %s' % (oq_values, relevant))
return relevant
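The hunk above replaces `exp(q / |max q|)` with the standard max-subtraction softmax and short-circuits when the chosen option's Q-value is `-inf` (an option that has been masked out). Subtracting the maximum leaves the distribution unchanged (`softmax(q) == softmax(q - c)` for any constant `c`) while preventing overflow, whereas the old division by `|max q|` silently rescaled the softmax temperature with the magnitude of the largest Q-value. A minimal standalone sketch of the new computation, assuming `q_values` is a 1-D array of per-option Q-values:

```python
import numpy as np

def softq_probability(q_values, action_num):
    """Softmax probability of option `action_num` over the given Q-values."""
    q_values = np.asarray(q_values, dtype=float)
    if q_values[action_num] == -np.inf:  # masked option: zero probability
        return 0.0
    # Shifting by the max is numerically safe and does not change the result.
    exp_q = np.exp(q_values - np.max(q_values))
    return exp_q[action_num] / np.sum(exp_q)
```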
@@ -540,6 +545,7 @@ class DQNAgentOverOptions(DQNAgent):
self.recent_observation = observation
self.recent_action = action
# print('forward gives %s from %s' % (action, dict(zip(self.low_level_policy_aliases, q_values))))
return action
def get_modified_q_values(self, observation):
......
@@ -3,9 +3,8 @@ from .mcts_learner import MCTSLearner
import tqdm
import numpy as np
class OnlineMCTSController(ControllerBase):
"""Online MCTS."""
class MCTSController(ControllerBase):
"""MCTS Controller."""
def __init__(self, env, low_level_policies, start_node_alias):
"""Constructor for manual policy execution.
@@ -14,13 +13,15 @@ class OnlineMCTSController(ControllerBase):
env: env instance
low_level_policies: low level policies dictionary
"""
super(OnlineMCTSController, self).__init__(env, low_level_policies,
super(MCTSController, self).__init__(env, low_level_policies,
start_node_alias)
self.curr_node_alias = start_node_alias
self.controller_args_defaults = {
"predictor": None,
"max_depth": 5, # MCTS depth
"nb_traversals": 30, # MCTS traversals before decision
"max_depth": 10, # MCTS depth
"nb_traversals": 100, # MCTS traversals before decision
"debug": False,
"rollout_timeout": 500
}
def set_current_node(self, node_alias):
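The new defaults add `debug` and `rollout_timeout` alongside the larger search budget (`max_depth` 5 → 10, `nb_traversals` 30 → 100). A hedged usage sketch, assuming `set_controller_args` accepts these keys as keyword overrides in the same way it is called with `predictor=...` further down; `env`, `low_level_policies`, `start_node`, and `predictor` are placeholder names here:

```python
controller = MCTSController(env, low_level_policies, start_node)
controller.set_controller_args(
    predictor=predictor,    # required; the transition code checks for it
    max_depth=10,           # MCTS depth
    nb_traversals=100,      # MCTS traversals before each decision
    debug=False,
    rollout_timeout=500)
```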
@@ -30,11 +31,21 @@ class OnlineMCTSController(ControllerBase):
self.env.set_ego_info_text(node_alias)
def change_low_level_references(self, env_copy):
# Create a copy of the environment and change references in low level policies.
"""Change references in low level policies by updating the environment
with the copy of the environment.
Args:
env_copy: reference to copy of the environment
"""
self.env = env_copy
for policy in self.low_level_policies.values():
policy.env = env_copy
def check_env(self, x):
"""Prints the object id of the environment. Debugging function."""
print('%s: self.env is %s' % (x, str(id(self.env))))
def can_transition(self):
return not self.env.is_terminal()
@@ -45,29 +56,54 @@ class OnlineMCTSController(ControllerBase):
"predictor is not set. Use set_controller_args().")
# Store the env at this point
orig_env = self.env
# self.check_env('i')
np.random.seed()
# Change low level references before init MCTSLearner instance
env_before_mcts = orig_env.copy()
self.change_low_level_references(env_before_mcts)
print('Current Node: %s' % self.curr_node_alias)
mcts = MCTSLearner(self.env, self.low_level_policies,
self.curr_node_alias)
mcts.max_depth = self.max_depth
mcts.set_controller_args(predictor=self.predictor)
# self.check_env('b4')
# print('Current Node: %s' % self.curr_node_alias)
if not hasattr(self, 'mcts'):
if self.debug:
print('Creating MCTS Tree: max depth {}'.format(self.max_depth))
self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth,
debug=self.debug, rollout_timeout=self.rollout_timeout)
self.mcts.set_controller_args(predictor=self.predictor)
if self.debug:
print('')
# Do nb_traversals number of traversals, reset env to this point every time
# print('Doing MCTS with params: max_depth = %d, nb_traversals = %d' % (self.max_depth, self.nb_traversals))
for num_epoch in range(
self.nb_traversals): # tqdm.tqdm(range(self.nb_traversals)):
mcts.curr_node_num = 0
num_epoch = 0
if not self.debug:
progress = tqdm.tqdm(total=self.nb_traversals-self.mcts.tree.root.N)
while num_epoch < self.nb_traversals: # tqdm
if self.mcts.tree.root.N >= self.nb_traversals:
break
env_begin_epoch = env_before_mcts.copy()
self.change_low_level_references(env_begin_epoch)
# self.check_env('e%d' % num_epoch)
init_obs = self.env.get_features_tuple()
v, all_ep_R = mcts.traverse(init_obs)
self.mcts.env = env_begin_epoch
if self.debug:
print('Search %d: ' % num_epoch, end=' ')
success = self.mcts.search(init_obs)
num_epoch += 1
if not self.debug:
progress.update(1)
if not self.debug:
progress.close()
self.change_low_level_references(orig_env)
# self.check_env('p')
# Find the nodes from the root node
mcts.curr_node_num = 0
print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
node_after_transition = mcts.get_best_node(
self.env.get_features_tuple(), use_ucb=False)
print('MCTS suggested next option: %s' % node_after_transition)
self.set_current_node(node_after_transition)
# print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
node_after_transition = self.mcts.best_action(self.mcts.tree.root, 0)
if self.debug:
print('MCTS suggested next option: %s' % node_after_transition)
p = {'overall': self.mcts.tree.root.Q * 1.0 / self.mcts.tree.root.N}
for edge in self.mcts.tree.root.edges.keys():
next_node = self.mcts.tree.nodes[self.mcts.tree.root.edges[edge]]
p[edge] = next_node.Q * 1.0 / next_node.N
if self.debug:
print('Btw: %s' % str(p))
self.mcts.tree.reconstruct(node_after_transition)
self.set_current_node(node_after_transition)
\ No newline at end of file
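Two behavioural changes in this hunk are worth spelling out. First, the learner is now created once and cached on `self.mcts`, and the search budget is expressed in root visits (`tree.root.N`), so statistics carried over between decisions count toward the budget. Second, after picking `node_after_transition` with `best_action(root, 0)` (the second argument presumably an exploration constant of 0, i.e. a greedy choice over the accumulated statistics; the `p` dict printed in debug mode reports the mean returns `Q/N` for the root and each child), `tree.reconstruct(...)` re-roots the tree at the chosen child so that its subtree is reused at the next decision. The `MCTSLearner` internals are collapsed in this diff, so the following is only a sketch of what such a re-rooting could look like, assuming the node/edge layout used above (`tree.nodes` indexed by id, `root.edges` mapping option aliases to child ids):

```python
def reconstruct(self, option_alias):
    """Re-root the tree at the child reached via `option_alias` (sketch).

    Statistics (Q, N) accumulated in that child's subtree survive to the
    next decision; the rest of the tree becomes unreachable.
    """
    new_root_id = self.root.edges[option_alias]
    self.root = self.nodes[new_root_id]
```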
{
"nodes": {
"wait": "MCTSWait",
"follow": "MCTSFollow",
"stop": "MCTSStop",
"changelane": "MCTSChangeLane",
"keeplane": "MCTSKeepLane"
"wait": "Wait",
"follow": "Follow",
"stop": "Stop",
"changelane": "ChangeLane",
"keeplane": "KeepLane"
},
"edges": {
......
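Since the MCTS-specific maneuver variants were dropped, every node alias now maps to the regular maneuver class. A hedged sketch of how such a `nodes` mapping could be resolved to classes, assuming the maneuver classes live in a module namespace as the star-imports in the options file below suggest; `load_nodes` and its arguments are hypothetical names:

```python
import json

def load_nodes(config_path, maneuver_module):
    """Map each node alias in the JSON config to a maneuver class."""
    with open(config_path) as f:
        config = json.load(f)
    return {alias: getattr(maneuver_module, class_name)
            for alias, class_name in config["nodes"].items()}
```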
import json
import os # for the use of os.path.isfile
from .simple_intersection.maneuvers import *
from .simple_intersection.mcts_maneuvers import *
from backends import RLController, DDPGLearner, MCTSLearner, OnlineMCTSController, ManualPolicy
from backends import RLController, DDPGLearner, MCTSController, ManualPolicy
class OptionsGraph:
"""Represent the options graph as a graph like structure. The configuration
@@ -67,11 +65,8 @@ class OptionsGraph:
self.controller = ManualPolicy(self.env, self.maneuvers, self.adj,
self.start_node_alias)
elif self.config["method"] == "mcts":
self.controller = MCTSLearner(self.env, self.maneuvers,
self.start_node_alias)
elif self.config["method"] == "online_mcts":
self.controller = OnlineMCTSController(self.env, self.maneuvers,
self.start_node_alias)
self.controller = MCTSController(self.env, self.maneuvers,
self.start_node_alias)
else:
raise Exception(self.__class__.__name__ + \
"Controller to be used not specified")
@@ -156,6 +151,7 @@ class OptionsGraph:
# TODO: error handling
def load_trained_low_level_policies(self):
for key, maneuver in self.maneuvers.items():
# TODO: Ensure that for manual policies, nothing is loaded
trained_policy_path = "backends/trained_policies/" + key + "/"
critic_file_exists = os.path.isfile(trained_policy_path + key + "_weights_critic.h5f")
actor_file_exists = os.path.isfile(trained_policy_path + key + "_weights_actor.h5f")
@@ -180,9 +176,6 @@ class OptionsGraph:
print("\n Warning: the trained low-level policy of \"" + key +
"\" does not exists; the manual policy will be used.\n")
if self.config["method"] == "mcts":
maneuver.timeout = np.inf
def get_number_of_nodes(self):
return len(self.maneuvers)
......
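`load_trained_low_level_policies` only loads a trained policy when both the actor and critic weight files are present; otherwise it warns and falls back to the manual policy. A small helper capturing that existence check, grounded in the file-naming scheme shown above (`trained_policy_available` is a hypothetical name):

```python
import os

def trained_policy_available(key, base="backends/trained_policies/"):
    """True iff both actor and critic weight files exist for maneuver `key`."""
    path = os.path.join(base, key)
    return (os.path.isfile(os.path.join(path, key + "_weights_actor.h5f")) and
            os.path.isfile(os.path.join(path, key + "_weights_critic.h5f")))
```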