diff --git a/backends/mcts_controller.py b/backends/mcts_controller.py
index 619a2260a56cf433da00c4b4141fbac22c200df8..2644419b9fe8f9ba9a85969e7e9bd2069d12b123 100644
--- a/backends/mcts_controller.py
+++ b/backends/mcts_controller.py
@@ -21,6 +21,7 @@ class MCTSController(ControllerBase):
             "max_depth": 10,  # MCTS depth
             "nb_traversals": 100,  # MCTS traversals before decision
             "debug": False,
+            "rollout_timeout": 500,  # max rollout iterations before cut-off
         }
 
     def set_current_node(self, node_alias):
@@ -65,7 +66,8 @@ class MCTSController(ControllerBase):
         if not hasattr(self, 'mcts'):
             if self.debug:
                 print('Creating MCTS Tree: max depth {}'.format(self.max_depth))
-            self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth, debug=self.debug)
+            self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth,
+                                    debug=self.debug, rollout_timeout=self.rollout_timeout)
             self.mcts.set_controller_args(predictor=self.predictor)
         if self.debug:
             print('')
diff --git a/backends/mcts_learner.py b/backends/mcts_learner.py
index 0248d6ae48e476c51197110b08c87d638b723d69..aa8ca4835293979afd4a1da043c4b662b558a13e 100644
--- a/backends/mcts_learner.py
+++ b/backends/mcts_learner.py
@@ -116,7 +116,8 @@ class Tree:
 class MCTSLearner(ControllerBase):
     """MCTS Logic."""
 
-    def __init__(self, env, low_level_policies, max_depth=10, debug=False):
+    def __init__(self, env, low_level_policies, max_depth=10, debug=False,
+                 rollout_timeout=500):
         """Constructor for MCTSLearner.
 
         Args:
@@ -124,6 +125,7 @@ class MCTSLearner(ControllerBase):
             low_level_policies: given low level maneuvers
             max_depth: the tree's max depth
             debug: whether or not to print debug statements
+            rollout_timeout: max number of rollout loop iterations before cut-off
         """
 
         self.env = env  # super?
@@ -131,6 +133,7 @@ class MCTSLearner(ControllerBase):
         self.controller_args_defaults = {"predictor": None}
         self.tree = Tree(max_depth=max_depth)
         self.debug = debug
+        self.rollout_timeout = rollout_timeout
 
     def reset(self):
         """Resets maneuvers and sets current node to root."""
@@ -249,11 +252,15 @@ class MCTSLearner(ControllerBase):
             # print('Reached depth %d' % self.tree.nodes[self.tree.curr_node_num].depth, end=' ')
             # print('at node: %d, reached leaf: %s, terminated: %s' % (self.tree.curr_node_num, reached_leaf, self.env.is_terminal()))
             if reached_leaf:
-                rollout_reward = self.def_policy()  # from leaf node
+                rollout_reward, timed_out = self.def_policy()  # from leaf node
                 if rollout_reward > 0:
                     self.backup(1.0)  # from leaf node
                     success = 1
-                elif rollout_reward < -150:
+                elif rollout_reward < -150 or timed_out:
+                    # TODO: -150 is arbitrary. It should be set from outside or
+                    # provided through a variable, based on the env. Same for timeout,
+                    # it is specific to the env. Also, for smaller timeouts it
+                    # may be a good idea to propagate 0 instead of -1.
                     self.backup(-1.0)
                 else:
                     self.backup(0)
@@ -353,7 +360,7 @@ class MCTSLearner(ControllerBase):
         rollout_reward = 0
         obs = self.tree.latest_obs
         it = 0
-        while not self.env.is_terminal():
+        while (not self.env.is_terminal()) and it < self.rollout_timeout:
             it += 1
             possible_options = self._get_possible_options()
             # print('possible is %s' % possible_options)
@@ -376,9 +383,10 @@ class MCTSLearner(ControllerBase):
             if eps_R != None:
                 rollout_reward += eps_R
         # print('Rollout steps = %d' % it)
+        timed_out = it >= self.rollout_timeout and not self.env.is_terminal()
         if self.debug:
             print(' <<%g>>' % rollout_reward, end=' ')
-        return rollout_reward
+        return rollout_reward, timed_out
 
     def backup(self, rollout_reward):
         """Reward backup strategy.