Skip to content
Snippets Groups Projects
Commit 051c8502 authored by Ashish Gaurav's avatar Ashish Gaurav
Browse files

add rollout timeout for MCTS

parent cb171a91
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,7 @@ class MCTSController(ControllerBase):
"max_depth": 10, # MCTS depth
"nb_traversals": 100, # MCTS traversals before decision
"debug": False,
"rollout_timeout": 500
}
def set_current_node(self, node_alias):
......@@ -65,7 +66,8 @@ class MCTSController(ControllerBase):
if not hasattr(self, 'mcts'):
if self.debug:
print('Creating MCTS Tree: max depth {}'.format(self.max_depth))
self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth, debug=self.debug)
self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth,
debug=self.debug, rollout_timeout=self.rollout_timeout)
self.mcts.set_controller_args(predictor=self.predictor)
if self.debug:
print('')
......
......@@ -116,7 +116,8 @@ class Tree:
class MCTSLearner(ControllerBase):
"""MCTS Logic."""
def __init__(self, env, low_level_policies, max_depth=10, debug=False):
def __init__(self, env, low_level_policies, max_depth=10, debug=False,
rollout_timeout=500):
"""Constructor for MCTSLearner.
Args:
......@@ -124,6 +125,7 @@ class MCTSLearner(ControllerBase):
low_level_policies: given low level maneuvers
max_depth: the tree's max depth
debug: whether or not to print debug statements
rollout_timeout: timeout for the rollout
"""
self.env = env # super?
......@@ -131,6 +133,7 @@ class MCTSLearner(ControllerBase):
self.controller_args_defaults = {"predictor": None}
self.tree = Tree(max_depth=max_depth)
self.debug = debug
self.rollout_timeout = rollout_timeout
def reset(self):
"""Resets maneuvers and sets current node to root."""
......@@ -249,11 +252,15 @@ class MCTSLearner(ControllerBase):
# print('Reached depth %d' % self.tree.nodes[self.tree.curr_node_num].depth, end=' ')
# print('at node: %d, reached leaf: %s, terminated: %s' % (self.tree.curr_node_num, reached_leaf, self.env.is_terminal()))
if reached_leaf:
rollout_reward = self.def_policy() # from leaf node
rollout_reward, timed_out = self.def_policy() # from leaf node
if rollout_reward > 0:
self.backup(1.0) # from leaf node
success = 1
elif rollout_reward < -150:
elif rollout_reward < -150 or timed_out:
# TODO: -150 is arbitrary. It should be set from outside or
# provided through a variable, based on the env. Same for timeout,
# it is specific to the env. Also, for smaller timeouts it
# may be a good idea to propagate 0 instead of -1.
self.backup(-1.0)
else:
self.backup(0)
......@@ -353,7 +360,7 @@ class MCTSLearner(ControllerBase):
rollout_reward = 0
obs = self.tree.latest_obs
it = 0
while not self.env.is_terminal():
while (not self.env.is_terminal()) and it < self.rollout_timeout:
it += 1
possible_options = self._get_possible_options()
# print('possible is %s' % possible_options)
......@@ -376,9 +383,10 @@ class MCTSLearner(ControllerBase):
if eps_R != None:
rollout_reward += eps_R
# print('Rollout steps = %d' % it)
timed_out = (it < self.rollout_timeout)
if self.debug:
print(' <<%g>>' % rollout_reward, end=' ')
return rollout_reward
return rollout_reward, timed_out
def backup(self, rollout_reward):
"""Reward backup strategy.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment