from .controller_base import ControllerBase
from .mcts_learner import MCTSLearner
import tqdm
import numpy as np


class OnlineMCTSController(ControllerBase):
    """Online MCTS.

    At every transition a fresh MCTSLearner is built, run for
    ``nb_traversals`` traversals on copies of the environment, and the
    option it ranks best becomes the current node.
    """

    def __init__(self, env, low_level_policies, start_node_alias):
        """Set up the controller and its default MCTS arguments.

        Args:
            env: env instance
            low_level_policies: low level policies dictionary, indexed by
                node alias
            start_node_alias: alias of the node the controller starts in
        """
        super(OnlineMCTSController, self).__init__(env, low_level_policies,
                                                   start_node_alias)
        self.curr_node_alias = start_node_alias
        # NOTE(review): presumably ControllerBase.set_controller_args() uses
        # these defaults to populate self.predictor, self.max_depth and
        # self.nb_traversals -- confirm against ControllerBase.
        self.controller_args_defaults = {
            "predictor": None,
            "max_depth": 5,  # MCTS depth
            "nb_traversals": 30,  # MCTS traversals before decision
        }

    def set_current_node(self, node_alias):
        """Make ``node_alias`` the active low level policy.

        The policy's maneuver is reset and the alias is pushed to the
        environment as the ego info text.

        Args:
            node_alias: key into ``self.low_level_policies``
        """
        self.current_node = self.low_level_policies[node_alias]
        self.curr_node_alias = node_alias
        self.current_node.reset_maneuver()
        self.env.set_ego_info_text(node_alias)

    def change_low_level_references(self, env_copy):
        """Point this controller and every low level policy at ``env_copy``.

        Args:
            env_copy: environment instance that should be used from now on
        """
        self.env = env_copy
        for policy in self.low_level_policies.values():
            policy.env = env_copy

    def can_transition(self):
        """Return True as long as the environment is not terminal."""
        return not self.env.is_terminal()

    def do_transition(self):
        """Run an MCTS search from the current state and switch to its pick.

        The search runs on copies of the environment. The original
        environment is re-attached to the controller and to all low level
        policies afterwards, including when the search raises.

        Raises:
            Exception: if no predictor has been set.
        """
        # Require a predictor function
        if self.predictor is None:
            raise Exception(
                self.__class__.__name__ +
                ": predictor is not set. Use set_controller_args().")

        # Store the env at this point
        orig_env = self.env
        # Called without an argument, this reseeds NumPy's global RNG.
        np.random.seed()

        # Change low level references before init MCTSLearner instance
        env_before_mcts = orig_env.copy()
        self.change_low_level_references(env_before_mcts)
        try:
            print('Current Node: %s' % self.curr_node_alias)
            mcts = MCTSLearner(self.env, self.low_level_policies,
                               self.curr_node_alias)
            mcts.max_depth = self.max_depth
            mcts.set_controller_args(predictor=self.predictor)

            # Do nb_traversals number of traversals, reset env to this point
            # every time
            for _ in range(self.nb_traversals):
                mcts.curr_node_num = 0
                env_begin_epoch = env_before_mcts.copy()
                self.change_low_level_references(env_begin_epoch)
                init_obs = self.env.get_features_tuple()
                mcts.traverse(init_obs)
        finally:
            # Re-attach the real environment even if the search failed, so
            # the controller and policies are never left bound to a copy.
            self.change_low_level_references(orig_env)

        # Find the nodes from the root node
        mcts.curr_node_num = 0
        print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
        node_after_transition = mcts.get_best_node(
            self.env.get_features_tuple(), use_ucb=False)
        print('MCTS suggested next option: %s' % node_after_transition)
        self.set_current_node(node_after_transition)