From 051c85024ed6f97c604c79b4c3088ea0b072ad40 Mon Sep 17 00:00:00 2001 From: Ashish Gaurav <agaurav77@yahoo.com> Date: Wed, 6 Feb 2019 15:23:58 -0500 Subject: [PATCH] add rollout timeout for MCTS --- backends/mcts_controller.py | 4 +++- backends/mcts_learner.py | 18 +++++++++++++----- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/backends/mcts_controller.py b/backends/mcts_controller.py index 619a226..2644419 100644 --- a/backends/mcts_controller.py +++ b/backends/mcts_controller.py @@ -21,6 +21,7 @@ class MCTSController(ControllerBase): "max_depth": 10, # MCTS depth "nb_traversals": 100, # MCTS traversals before decision "debug": False, + "rollout_timeout": 500 } def set_current_node(self, node_alias): @@ -65,7 +66,8 @@ class MCTSController(ControllerBase): if not hasattr(self, 'mcts'): if self.debug: print('Creating MCTS Tree: max depth {}'.format(self.max_depth)) - self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth, debug=self.debug) + self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth, + debug=self.debug, rollout_timeout=self.rollout_timeout) self.mcts.set_controller_args(predictor=self.predictor) if self.debug: print('') diff --git a/backends/mcts_learner.py b/backends/mcts_learner.py index 0248d6a..aa8ca48 100644 --- a/backends/mcts_learner.py +++ b/backends/mcts_learner.py @@ -116,7 +116,8 @@ class Tree: class MCTSLearner(ControllerBase): """MCTS Logic.""" - def __init__(self, env, low_level_policies, max_depth=10, debug=False): + def __init__(self, env, low_level_policies, max_depth=10, debug=False, + rollout_timeout=500): """Constructor for MCTSLearner. Args: @@ -124,6 +125,7 @@ class MCTSLearner(ControllerBase): low_level_policies: given low level maneuvers max_depth: the tree's max depth debug: whether or not to print debug statements + rollout_timeout: timeout for the rollout """ self.env = env # super? @@ -131,6 +133,7 @@ class MCTSLearner(ControllerBase): self.controller_args_defaults = {"predictor": None} self.tree = Tree(max_depth=max_depth) self.debug = debug + self.rollout_timeout = rollout_timeout def reset(self): """Resets maneuvers and sets current node to root.""" @@ -249,11 +252,15 @@ class MCTSLearner(ControllerBase): # print('Reached depth %d' % self.tree.nodes[self.tree.curr_node_num].depth, end=' ') # print('at node: %d, reached leaf: %s, terminated: %s' % (self.tree.curr_node_num, reached_leaf, self.env.is_terminal())) if reached_leaf: - rollout_reward = self.def_policy() # from leaf node + rollout_reward, timed_out = self.def_policy() # from leaf node if rollout_reward > 0: self.backup(1.0) # from leaf node success = 1 - elif rollout_reward < -150: + elif rollout_reward < -150 or timed_out: + # TODO: -150 is arbitrary. It should be set from outside or + # provided through a variable, based on the env. Same for timeout, + # it is specific to the env. Also, for smaller timeouts it + # may be a good idea to propagate 0 instead of -1. self.backup(-1.0) else: self.backup(0) @@ -353,7 +360,7 @@ class MCTSLearner(ControllerBase): rollout_reward = 0 obs = self.tree.latest_obs it = 0 - while not self.env.is_terminal(): + while (not self.env.is_terminal()) and it < self.rollout_timeout: it += 1 possible_options = self._get_possible_options() # print('possible is %s' % possible_options) @@ -376,9 +383,10 @@ class MCTSLearner(ControllerBase): if eps_R != None: rollout_reward += eps_R # print('Rollout steps = %d' % it) + timed_out = (it < self.rollout_timeout) if self.debug: print(' <<%g>>' % rollout_reward, end=' ') - return rollout_reward + return rollout_reward, timed_out def backup(self, rollout_reward): """Reward backup strategy. -- GitLab