Skip to content
Snippets Groups Projects
Commit 8c6ddf97 authored by Ashish Gaurav's avatar Ashish Gaurav
Browse files

Merge branch 'mcts_rollout_timeout' into 'master'

add rollout timeout for MCTS

See merge request !6
parents cb171a91 051c8502
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,7 @@ class MCTSController(ControllerBase):
"max_depth": 10, # MCTS depth
"nb_traversals": 100, # MCTS traversals before decision
"debug": False,
"rollout_timeout": 500
}
def set_current_node(self, node_alias):
......@@ -65,7 +66,8 @@ class MCTSController(ControllerBase):
if not hasattr(self, 'mcts'):
if self.debug:
print('Creating MCTS Tree: max depth {}'.format(self.max_depth))
self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth, debug=self.debug)
self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth,
debug=self.debug, rollout_timeout=self.rollout_timeout)
self.mcts.set_controller_args(predictor=self.predictor)
if self.debug:
print('')
......
......@@ -116,7 +116,8 @@ class Tree:
class MCTSLearner(ControllerBase):
"""MCTS Logic."""
def __init__(self, env, low_level_policies, max_depth=10, debug=False):
def __init__(self, env, low_level_policies, max_depth=10, debug=False,
rollout_timeout=500):
"""Constructor for MCTSLearner.
Args:
......@@ -124,6 +125,7 @@ class MCTSLearner(ControllerBase):
low_level_policies: given low level maneuvers
max_depth: the tree's max depth
debug: whether or not to print debug statements
rollout_timeout: timeout for the rollout
"""
self.env = env # super?
......@@ -131,6 +133,7 @@ class MCTSLearner(ControllerBase):
self.controller_args_defaults = {"predictor": None}
self.tree = Tree(max_depth=max_depth)
self.debug = debug
self.rollout_timeout = rollout_timeout
def reset(self):
"""Resets maneuvers and sets current node to root."""
......@@ -249,11 +252,15 @@ class MCTSLearner(ControllerBase):
# print('Reached depth %d' % self.tree.nodes[self.tree.curr_node_num].depth, end=' ')
# print('at node: %d, reached leaf: %s, terminated: %s' % (self.tree.curr_node_num, reached_leaf, self.env.is_terminal()))
if reached_leaf:
rollout_reward = self.def_policy() # from leaf node
rollout_reward, timed_out = self.def_policy() # from leaf node
if rollout_reward > 0:
self.backup(1.0) # from leaf node
success = 1
elif rollout_reward < -150:
elif rollout_reward < -150 or timed_out:
# TODO: -150 is arbitrary. It should be set from outside or
# provided through a variable, based on the env. Same for timeout,
# it is specific to the env. Also, for smaller timeouts it
# may be a good idea to propagate 0 instead of -1.
self.backup(-1.0)
else:
self.backup(0)
......@@ -353,7 +360,7 @@ class MCTSLearner(ControllerBase):
rollout_reward = 0
obs = self.tree.latest_obs
it = 0
while not self.env.is_terminal():
while (not self.env.is_terminal()) and it < self.rollout_timeout:
it += 1
possible_options = self._get_possible_options()
# print('possible is %s' % possible_options)
......@@ -376,9 +383,10 @@ class MCTSLearner(ControllerBase):
if eps_R != None:
rollout_reward += eps_R
# print('Rollout steps = %d' % it)
timed_out = (it < self.rollout_timeout)
if self.debug:
print(' <<%g>>' % rollout_reward, end=' ')
return rollout_reward
return rollout_reward, timed_out
def backup(self, rollout_reward):
"""Reward backup strategy.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment