Commit cb171a91 authored by Ashish Gaurav

MCTS Fixes

* Reimplemented UCT MCTS (a sketch of the selection rule follows the commit header)
* Fixed softmax (numerically stable normalization of the Q values)
* Merged several branches into this one; all of them should now be in master
* Added tree-reuse functionality
* Added the ability to expand nodes based on Q values rather than at random
* Refactored everything, deleted unnecessary MCTS classes and files; mcts.py can now evaluate the newer MCTS
parent 062ad4ff
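The selection rule behind the reimplemented UCT search (first bullet above) is only partially visible in the best_action diff further down. Below is a minimal sketch of that rule, assuming the uncertainty term keeps the C * predictor(obs, option) / (1 + N) form of the removed _ucb_adjusted_q; the function name and signature are illustrative, only Q, N, edges and predictor come from the diff.

def select_option(tree, node, obs, predictor, C=1.0):
    # Score each child edge: mean value Q/(N+1) plus, when C > 0, an
    # uncertainty bonus scaled by the learned predictor (assumed form).
    scores = {}
    for option_alias, child_num in node.edges.items():
        child = tree.nodes[child_num]
        score = child.Q / (child.N + 1.0)
        if C > 0:
            score += C * predictor(obs, option_alias) / (1.0 + child.N)
        scores[option_alias] = score
    # With C = 0 this is a pure greedy pick over mean values, which is how
    # the controller makes its final decision after all traversals.
    return max(scores, key=scores.get)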
@@ -2,4 +2,4 @@ from .manual_policy import ManualPolicy
from .mcts_learner import MCTSLearner
from .rl_controller import RLController
from .kerasrl_learner import DDPGLearner, DQNLearner
from .online_mcts_controller import OnlineMCTSController
\ No newline at end of file
from .mcts_controller import MCTSController
\ No newline at end of file
@@ -13,7 +13,7 @@ from rl.policy import GreedyQPolicy, EpsGreedyQPolicy, MaxBoltzmannQPolicy
from rl.callbacks import ModelIntervalCheckpoint
import numpy as np
import copy
class DDPGLearner(LearnerBase):
def __init__(self,
@@ -270,17 +270,17 @@ class DQNLearner(LearnerBase):
Returns: Keras Model object of actor
"""
model = Sequential()
model.add(Flatten(input_shape=(1, ) + self.input_shape))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='tanh'))
model.add(Dense(self.nb_actions))
model.add(Activation('linear'))
# print(model.summary())
print(model.summary())
return model
@@ -416,9 +416,14 @@ class DQNLearner(LearnerBase):
action_num = self.agent_model.low_level_policy_aliases.index(
option_alias)
q_values = self.agent_model.get_modified_q_values(observation)
max_q_value = np.abs(np.max(q_values))
q_values = [np.exp(q_value / max_q_value) for q_value in q_values]
# print('softq q_values are %s' % dict(zip(self.agent_model.low_level_policy_aliases, q_values)))
# oq_values = copy.copy(q_values)
if q_values[action_num] == -np.inf:
return 0
max_q_value = np.max(q_values)
q_values = [np.exp(q_value - max_q_value) for q_value in q_values]
relevant = q_values[action_num] / np.sum(q_values)
# print('softq: %s -> %s' % (oq_values, relevant))
return relevant
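For reference, a minimal standalone sketch of the fixed soft-Q probability above: subtracting the max Q value before exponentiating keeps np.exp from overflowing, and options masked with -inf (see get_modified_q_values below) naturally receive zero probability. The function name is illustrative only.

import numpy as np

def soft_q_probability(q_values, action_num):
    # A masked-out option contributes nothing and has probability 0.
    if q_values[action_num] == -np.inf:
        return 0.0
    # Subtracting the max keeps every exponent <= 0 (no overflow) while
    # leaving the normalized ratio unchanged.
    max_q = np.max(q_values)
    exp_q = np.exp(np.asarray(q_values, dtype=float) - max_q)
    return exp_q[action_num] / np.sum(exp_q)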
@@ -543,6 +548,7 @@ class DQNAgentOverOptions(DQNAgent):
self.recent_observation = observation
self.recent_action = action
# print('forward gives %s from %s' % (action, dict(zip(self.low_level_policy_aliases, q_values))))
return action
def get_modified_q_values(self, observation):
@@ -555,4 +561,4 @@ class DQNAgentOverOptions(DQNAgent):
for node_index in invalid_node_indices:
q_values[node_index] = -np.inf
return q_values
\ No newline at end of file
return q_values
@@ -3,9 +3,8 @@ from .mcts_learner import MCTSLearner
import tqdm
import numpy as np
class OnlineMCTSController(ControllerBase):
"""Online MCTS."""
class MCTSController(ControllerBase):
"""MCTS Controller."""
def __init__(self, env, low_level_policies, start_node_alias):
"""Constructor for manual policy execution.
@@ -14,13 +13,14 @@ class OnlineMCTSController(ControllerBase):
env: env instance
low_level_policies: low level policies dictionary
"""
super(OnlineMCTSController, self).__init__(env, low_level_policies,
super(MCTSController, self).__init__(env, low_level_policies,
start_node_alias)
self.curr_node_alias = start_node_alias
self.controller_args_defaults = {
"predictor": None,
"max_depth": 5, # MCTS depth
"nb_traversals": 30, # MCTS traversals before decision
"max_depth": 10, # MCTS depth
"nb_traversals": 100, # MCTS traversals before decision
"debug": False,
}
def set_current_node(self, node_alias):
@@ -30,11 +30,21 @@ class OnlineMCTSController(ControllerBase):
self.env.set_ego_info_text(node_alias)
def change_low_level_references(self, env_copy):
# Create a copy of the environment and change references in low level policies.
"""Change references in low level policies by updating the environment
with the copy of the environment.
Args:
env_copy: reference to copy of the environment
"""
self.env = env_copy
for policy in self.low_level_policies.values():
policy.env = env_copy
def check_env(self, x):
"""Prints the object id of the environment. Debugging function."""
print('%s: self.env is %s' % (x, str(id(self.env))))
def can_transition(self):
return not self.env.is_terminal()
@@ -45,29 +55,53 @@ class OnlineMCTSController(ControllerBase):
"predictor is not set. Use set_controller_args().")
# Store the env at this point
orig_env = self.env
# self.check_env('i')
np.random.seed()
# Change low level references before init MCTSLearner instance
env_before_mcts = orig_env.copy()
self.change_low_level_references(env_before_mcts)
print('Current Node: %s' % self.curr_node_alias)
mcts = MCTSLearner(self.env, self.low_level_policies,
self.curr_node_alias)
mcts.max_depth = self.max_depth
mcts.set_controller_args(predictor=self.predictor)
# self.check_env('b4')
# print('Current Node: %s' % self.curr_node_alias)
if not hasattr(self, 'mcts'):
if self.debug:
print('Creating MCTS Tree: max depth {}'.format(self.max_depth))
self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth, debug=self.debug)
self.mcts.set_controller_args(predictor=self.predictor)
if self.debug:
print('')
# Do nb_traversals number of traversals, reset env to this point every time
# print('Doing MCTS with params: max_depth = %d, nb_traversals = %d' % (self.max_depth, self.nb_traversals))
for num_epoch in range(
self.nb_traversals): # tqdm.tqdm(range(self.nb_traversals)):
mcts.curr_node_num = 0
num_epoch = 0
if not self.debug:
progress = tqdm.tqdm(total=self.nb_traversals-self.mcts.tree.root.N)
while num_epoch < self.nb_traversals: # tqdm
if self.mcts.tree.root.N >= self.nb_traversals:
break
env_begin_epoch = env_before_mcts.copy()
self.change_low_level_references(env_begin_epoch)
# self.check_env('e%d' % num_epoch)
init_obs = self.env.get_features_tuple()
v, all_ep_R = mcts.traverse(init_obs)
self.mcts.env = env_begin_epoch
if self.debug:
print('Search %d: ' % num_epoch, end=' ')
success = self.mcts.search(init_obs)
num_epoch += 1
if not self.debug:
progress.update(1)
if not self.debug:
progress.close()
self.change_low_level_references(orig_env)
# self.check_env('p')
# Find the nodes from the root node
mcts.curr_node_num = 0
print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
node_after_transition = mcts.get_best_node(
self.env.get_features_tuple(), use_ucb=False)
print('MCTS suggested next option: %s' % node_after_transition)
self.set_current_node(node_after_transition)
# print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
node_after_transition = self.mcts.best_action(self.mcts.tree.root, 0)
if self.debug:
print('MCTS suggested next option: %s' % node_after_transition)
p = {'overall': self.mcts.tree.root.Q * 1.0 / self.mcts.tree.root.N}
for edge in self.mcts.tree.root.edges.keys():
next_node = self.mcts.tree.nodes[self.mcts.tree.root.edges[edge]]
p[edge] = next_node.Q * 1.0 / next_node.N
if self.debug:
print('Btw: %s' % str(p))
self.mcts.tree.reconstruct(node_after_transition)
self.set_current_node(node_after_transition)
\ No newline at end of file
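A hedged usage sketch of the renamed controller, based on the constructor signature and the defaults visible above. env, low_level_policies, predictor_fn and the 'wait' start alias are placeholders, and do_transition is assumed to be the ControllerBase entry point whose MCTS body is shown in this hunk.

# import MCTSController from the package's mcts_controller module
# (module path as in the new __init__.py; package name not shown in the diff)
controller = MCTSController(env, low_level_policies, start_node_alias='wait')
controller.set_controller_args(
    predictor=predictor_fn,  # P(s, o) estimate, e.g. a learner's soft-Q value
    max_depth=10,            # new default MCTS depth
    nb_traversals=100,       # new default number of searches per decision
    debug=False)

while controller.can_transition():
    # Assumed name of the per-decision method whose search loop appears above.
    controller.do_transition()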
@@ -2,116 +2,196 @@ from .controller_base import ControllerBase
import numpy as np
import pickle
# TODO: Clean up debug comments
class MCTSLearner(ControllerBase):
"""Monte Carlo Tree Search implementation using the UCB1 and progressive
widening approach as explained in Paxton et al (2017)."""
_ucb_vals = set()
class Node:
"""Represents a node in a tree."""
def __init__(self, env, low_level_policies, start_node_alias,
max_depth=10):
"""Constructor for MCTSLearner.
def __init__(self, node_num):
"""Initialize a node.
Args:
env: env instance
low_level_policies: low level policies dictionary
start_node: starting node
predictor: P(s, o) learner class; forward pass should
return the entire value from state s and option o
max_depth: max depth of the MCTS tree; default 10 levels
node_num: the unique number of a node in the tree
"""
super(MCTSLearner, self).__init__(env, low_level_policies,
start_node_alias)
self.num = node_num
self.state = ''
self.cstate = None
self.edges = {}
self.parent_num = None
self.N = 0
self.Q = 0
self.depth = None
def is_terminal(self):
"""Check whether this node is a leaf node or not."""
return self.edges == {}
class Tree:
"""Tree representation used for MCTS."""
def __init__(self, max_depth):
"""Constructor for Tree.
Args:
max_depth: max possible distance between the root node
and any leaf node.
"""
self.controller_args_defaults = {
"predictor":
None #P(s, o) learner class; forward pass should return the entire value from state s and option o
}
self.max_depth = max_depth
#: store current node alias
self.curr_node_alias = start_node_alias
#: store current node's id
self.curr_node_num = 0
#: when a new node is created, it has this id;
# afterwards this number is updated
self.new_node_num = 0
#: visitation count of discrete observations
self.N = {}
#: visitation count of discrete observation with option
self.M = {}
#: total reward from given discrete observation with option
self.TR = {}
#: node properties
self.nodes = {}
#: adjacency list
self.adj = {}
# populate root node
root_node_num, root_node_info = self._create_node(self.curr_node_alias)
self.nodes[root_node_num] = root_node_info
self.adj[root_node_num] = set() # no children
def save_model(self, file_name="mcts.pickle"):
to_backup = {
'N': self.N,
'M': self.M,
'TR': self.TR,
'nodes': self.nodes,
'adj': self.adj,
'new_node_num': self.new_node_num
}
with open(file_name, 'wb') as handle:
pickle.dump(to_backup, handle, protocol=pickle.HIGHEST_PROTOCOL)
def load_model(self, file_name='mcts.pickle'):
with open(file_name, 'rb') as handle:
to_restore = pickle.load(handle)
self.N = to_restore['N']
self.M = to_restore['M']
self.TR = to_restore['TR']
self.nodes = to_restore['nodes']
self.adj = to_restore['adj']
self.new_node_num = to_restore['new_node_num']
def _create_node(self, low_level_policy):
"""Create the node associated with curr_node_num, using the given low
level policy.
self.new_node_num = 0
self.root = self._create_node()
self.root_node_num = 0
self.root.depth = 0
self.curr_node_num = 0
self.max_depth = max_depth
self.latest_obs = None
self.latest_dis_obs = None
def _create_node(self):
"""Internal function to create a node, without edges."""
created_node = Node(self.new_node_num)
self.nodes[self.new_node_num] = created_node
self.new_node_num += 1
return created_node
def new_node(self, option_alias):
"""Creates a new node that is a child of the current node.
The new node can be reached by the option_alias edge from
the current node.
Args:
low_level_policy: the option's alias
option_alias: the edge between current node and new node
"""
created_node = self._create_node()
self.nodes[self.curr_node_num].edges[option_alias] = created_node.num
created_node.parent_num = self.curr_node_num
created_node.depth = self.nodes[self.curr_node_num].depth+1
Returns new node id, and dict with all the node properties of
the created node.
def move(self, option_alias):
"""Use the edge option_alias and move from current node to
a next node.
Args:
option_alias: edge to move along
"""
# print('Creating new node %d:%s' % (self.new_node_num, low_level_policy))
created_node_num = self.new_node_num
self.new_node_num += 1
return created_node_num, {"policy": low_level_policy}
possible_edges = self.nodes[self.curr_node_num].edges
assert(option_alias in possible_edges.keys())
self.curr_node_num = possible_edges[option_alias]
def add_state(self, obs, dis_obs):
"""Associates observation and discrete observation to the current
node. Useful to keep track of the last known observation(s).
Args:
obs: observation to save to current node's cstate
dis_obs: observation to save to current node's state
"""
if self.nodes[self.curr_node_num].state == '':
self.nodes[self.curr_node_num].state = dis_obs
self.nodes[self.curr_node_num].cstate = obs
# assert(self.nodes[self.curr_node_num].state == dis_obs)
# assert on the continuous obs?
self.latest_obs = obs
self.latest_dis_obs = dis_obs
def reconstruct(self, option_alias):
"""Use the option_alias from the root node and reposition the tree
such that the new root node is the node reached by following option_alias
from the current root node.
Args:
option_alias: edge to follow from the root node
"""
new_root_num = self.nodes[self.root_node_num].edges[option_alias]
self.root_node_num = new_root_num
self.root = self.nodes[self.root_node_num]
self.curr_node_num = self.root_node_num
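# Illustrative sketch (assumed usage of the Tree above; 'follow' is only a
# placeholder for whatever option aliases the project defines):
#
#     tree = Tree(max_depth=10)
#     tree.add_state(obs, dis_obs)            # remember the root's observations
#     tree.new_node('follow')                 # add an edge out of the current node
#     tree.move('follow')                     # descend along that edge
#     tree.add_state(next_obs, next_dis_obs)
#     tree.reconstruct('follow')              # tree reuse: that child becomes the new root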
class MCTSLearner(ControllerBase):
"""MCTS Logic."""
def __init__(self, env, low_level_policies, max_depth=10, debug=False):
"""Constructor for MCTSLearner.
Args:
env: environment
low_level_policies: given low level maneuvers
max_depth: the tree's max depth
debug: whether or not to print debug statements
"""
self.env = env # super?
self.low_level_policies = low_level_policies # super?
self.controller_args_defaults = {"predictor": None}
self.tree = Tree(max_depth=max_depth)
self.debug = debug
def reset(self):
"""Resets maneuvers and sets current node to root."""
all_options = set(self.low_level_policies.keys())
for option_alias in all_options:
self.low_level_policies[option_alias].reset_maneuver()
self.tree.curr_node_num = self.tree.root_node_num
def _get_possible_options(self):
"""Return all option_alias whose init condition is satisfied from
the current node."""
all_options = set(self.low_level_policies.keys())
filtered_options = set()
for option_alias in all_options:
self.low_level_policies[option_alias].reset_maneuver()
if self.low_level_policies[option_alias].initiation_condition:
filtered_options.add(option_alias)
return filtered_options
def _get_sorted_possible_options(self, node):
"""Return all option_alias whose init condition is satisfied from
the given node. The options are returned in decreasing order of
preference, which is determined from the q values given by the
predictor.
Args:
node: the node from which the possible options are to be found
"""
possible_options = self._get_possible_options()
q_values = {}
for option_alias in possible_options:
q_values[option_alias] = self.predictor(node.cstate, option_alias)
# print('to sort %s' % q_values)
sorted_list = sorted(possible_options, key=lambda z: q_values[z])
# print('sorted: %s' % sorted_list)
return sorted_list[::-1]
def _to_discrete(self, observation):
"""Converts observation to a discrete observation tuple. Also append
(a) whether we are following a vehicle, and (b) whether there is a
vehicle in the opposite lane at approximately the same x position.
These values will be useful for Follow and ChangeLane maneuvers.
"""Converts a given observation to discrete form. Since the
discrete features don't capture all the information in the
environment, two additional features are added.
Args:
observation: observation tuple from the environment
observation: observation to discretize
Returns the discrete observation
Returns a string of 0s and 1s
"""
dis_observation = ''
for item in observation[12:20]:
if type(item) == bool:
dis_observation += '1' if item is True else '0'
if type(item) == int and item in [0, 1]:
dis_observation += str(item)
env = self.current_node.env
assert(len(dis_observation) == 8)
# Are we following a vehicle?
target_veh_i, V2V_dist = env.get_V2V_distance()
target_veh_i, V2V_dist = self.env.get_V2V_distance()
if target_veh_i is not None and V2V_dist <= 30:
dis_observation += '1'
else:
@@ -120,9 +200,9 @@ class MCTSLearner(ControllerBase):
# Is there a vehicle in the opposite lane in approximately the
# same x position?
delta_x = 15
ego = env.vehs[0]
ego = self.env.vehs[0]
possible_collision = '0'
for veh in env.vehs[1:]:
for veh in self.env.vehs[1:]:
abs_delta_x = abs(ego.x - veh.x)
opp_lane = ego.APs['lane'] != veh.APs['lane']
is_horizontal = veh.APs['on_route']
@@ -133,251 +213,219 @@ class MCTSLearner(ControllerBase):
return dis_observation
def _get_visitation_count(self, observation, option=None):
"""Finds the visitation count of the discrete form of the observation.
If discrete observation not found, then inserted into self.N with value
0. Auto converts the observation into discrete form. If option is not
None, then this uses self.M instead of self.N.
def move(self, option_alias):
"""Move in the MCTS tree. This means moving in the tree, updating
state information and stepping through option_alias.
Args:
observation: observation tuple from the environment
option: option alias
Returns N(discrete form of the observation) or N(observation, option)
depending on usage
option_alias: edge or option to execute to reach a next node
"""
dis_observation = self._to_discrete(observation)
if option is None:
if dis_observation not in self.N:
self.N[dis_observation] = 0
return self.N[dis_observation]
else:
if (dis_observation, option) not in self.M:
self.M[(dis_observation, option)] = 0
return self.M[(dis_observation, option)]
def _get_q_star(self, observation, option):
"""Compute average value of discrete observation - option pair.
# take the edge option_alias and move to a new node
self.set_current_node(option_alias)
pre_x = self.env.vehs[0].x
next_obs, eps_R = self.option_step()
post_x = self.env.vehs[0].x
dis_obs = self._to_discrete(next_obs)
self.tree.move(option_alias)
self.tree.add_state(next_obs, dis_obs)
# print('moved %s: %f-> %f' % (option_alias, pre_x, post_x))
if eps_R is None:
return 0
return eps_R
def search(self, obs):
"""Perform a traversal from the root node.
Args:
observation: observation tuple from the environment
option: option alias
Returns Q_star(discrete form of the observation, option)
obs: current observation
"""
dis_observation = self._to_discrete(observation)
if (dis_observation, option) not in self.TR:
self.TR[(dis_observation, option)] = 0
return self.TR[(dis_observation, option)] / (
1 + self._get_visitation_count(observation, option))
def _get_possible_options(self):
"""Returns a set of options that can be taken from the current node.
Goes through adjacency set of current node and finds which next
nodes' initiation condition is met.
"""
all_options = set(self.low_level_policies.keys())
# Filter nodes whose initiation condition are true
filtered_options = set()
for option_alias in all_options:
self.low_level_policies[option_alias].reset_maneuver()
if self.low_level_policies[option_alias].initiation_condition:
filtered_options.add(option_alias)
return filtered_options
def _get_already_visited_options(self, node_num):
"""Returns a set of options that have already been tried at least once
from the given node.
success = 0
self.reset() # reset tree and get to root
dis_obs = self._to_discrete(obs)
self.tree.add_state(obs, dis_obs)
reached_leaf = self.tree_policy() # until we reach a leaf
# print('Reached depth %d' % self.tree.nodes[self.tree.curr_node_num].depth, end=' ')
# print('at node: %d, reached leaf: %s, terminated: %s' % (self.tree.curr_node_num, reached_leaf, self.env.is_terminal()))
if reached_leaf:
rollout_reward = self.def_policy() # from leaf node
if rollout_reward > 0:
self.backup(1.0) # from leaf node
success = 1
elif rollout_reward < -150:
self.backup(-1.0)
else:
self.backup(0)
else:
# violation happened or we terminated before reaching a leaf (weird)
if self.current_node.env.goal_achieved:
self.backup(1.0)
else:
self.backup(-1.0)
# p = {'overall': self.tree.root.Q * 1.0 / self.tree.root.N}
# for edge in self.tree.root.edges.keys():
# next_node = self.tree.nodes[self.tree.root.edges[edge]]
# p[edge] = next_node.Q * 1.0 / next_node.N
# return self.best_action(self.tree.root, 0), success, p
return success
def tree_policy(self):
"""Policy that determines how to move through the MCTS tree.
Terminates either when environment reaches a terminal state, or we
reach a leaf node."""
while not self.env.is_terminal():
node = self.tree.nodes[self.tree.curr_node_num]
# Implementation IDEA: Choose the option which has the highest q first
possible_options = self._get_sorted_possible_options(node)
# possible_options = self._get_possible_options()
# print('at node: %d' % node.num)
already_visited = set(node.edges.keys())
not_visited = []
for item in possible_options:
if item not in already_visited:
not_visited.append(item)
# print('not visited %s' % not_visited)
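# Descriptive note: this is the progressive-widening-style check from the
# original class docstring (Paxton et al., 2017) -- expand a new child only
# while sqrt(node.N) is at least the number of children already tried, and
# only if the new child would still lie within max_depth of the current root.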
if len(not_visited) > 0 and node.depth+1-self.tree.root.depth < self.tree.max_depth and np.power(node.N, 0.5) >= len(already_visited):
self.expand(node, not_visited)
return True
elif node.depth+1-self.tree.root.depth >= self.tree.max_depth:
return True # can't go beyond, just do a rollout
else:
option_alias = self.best_action(node, 1)
if self.debug:
print("%s" % option_alias[:2], end=',')
eps_R = self.move(option_alias)
if eps_R < -150: # violation
return False
return False
def expand(self, node, not_visited):
"""Create a new node from the given node. Chooses an option from the
not_visited list. Also moves to the newly created node.
Args:
node_num: node number of the node from which the return set has to
be computed
Returns a set of aliases of already visited options from node_num
node: node to expand from
not_visited: new possible edges or options
"""
visited_aliases = set()
for already_existing_node_num in self.adj[node_num]:
already_existing_node = self.nodes[already_existing_node_num]
node_alias = already_existing_node["policy"]
visited_aliases.add(node_alias)
return visited_aliases
# Implementation IDEA: Choose the option which has the highest q first
random_option_alias = list(not_visited)[0] # np.random.choice(list(not_visited))
if self.debug:
print("expand(%s)"%random_option_alias[:2], end=',')
self.tree.new_node(random_option_alias)
self.move(random_option_alias)
def _ucb_adjusted_q(self, observation, C=1):
"""Computes Q_star(observation, option_i) plus the UCB term, which is
C*[predictor(observation, option_i)]/[1+N(observation, option_i)], for
all option_i in the adjacency set of the current node.
def best_action(self, node, C):
"""Find the best option to execute from a given node. The constant
C determines the coefficient of the uncertainty estimate.
Args:
observation: observation tuple from the environment
C: UCB constant, experimentally defined
node: node from which the best action needs to be found
C: coefficient of uncertainty term
Returns Q values for next nodes
Returns the best possible option alias from given node.
"""
Q = {}
Q1, Q2 = {}, {} # debug
dis_observation = self._to_discrete(observation)
next_option_nums = self.adj[self.curr_node_num]
for next_option_num in next_option_nums:
next_option = self.nodes[next_option_num]["policy"]
Q1[next_option] = (
self._get_q_star(observation, next_option) + 200) / 400
Q[(dis_observation, next_option)] = \
Q1[next_option]
Q2[next_option] = C * \
(self.predictor(observation, next_option)) / \
(1 + self._get_visitation_count(observation, next_option))
self._ucb_vals.add(Q2[next_option])
Q[(dis_observation, next_option)] += \
Q2[next_option]
return Q, Q1, Q2
def _value(self, observation):
"""Computes the value function v(s) by computing sum of all
TR[dis_observation, option_i] divided by N[dis_observation]
next_options = list(node.edges.keys())
Q_UCB = {}
for option_alias in next_options:
next_node_num = node.edges[option_alias]
next_node = self.tree.nodes[next_node_num]
Q_UCB[option_alias] = 0
Q_UCB[option_alias] += next_node.Q / (next_node.N + 1)
obs = self.tree.latest_obs
pred = self.predictor(obs, option_alias)
if not np.isnan(pred):