online_mcts_controller.py 2.97 KB
Newer Older
Aravind Bk's avatar
Aravind Bk committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
from .controller_base import ControllerBase
from .mcts_learner import MCTSLearner
import tqdm
import numpy as np

class OnlineMCTSController(ControllerBase):
    """Online MCTS controller.

    On every high-level transition it snapshots the current env, runs
    ``nb_traversals`` MCTS traversals (each on a fresh copy of that snapshot),
    restores the real env, and switches to the option that
    ``MCTSLearner.get_best_node`` selects from the root node.
    """

    def __init__(self, env, low_level_policies, start_node_alias):
        """Constructor for the online MCTS controller.

        Args:
            env: env instance
            low_level_policies: low level policies dictionary, keyed by
                node (option) alias
            start_node_alias: alias of the low level policy to start with
        """
        super(OnlineMCTSController, self).__init__(env, low_level_policies, start_node_alias)
        self.curr_node_alias = start_node_alias
        # Defaults for the arguments configured through set_controller_args();
        # do_transition reads them back as self.predictor / self.max_depth /
        # self.nb_traversals (attribute wiring presumably lives in
        # ControllerBase -- confirm there).
        self.controller_args_defaults = {
            "predictor": None,
            "max_depth": 5,  # MCTS depth
            "nb_traversals": 30,  # MCTS traversals before decision
        }

    def set_current_node(self, node_alias):
        """Make ``node_alias`` the active low level policy.

        Calls ``reset_maneuver()`` on the newly selected policy and passes the
        alias to ``env.set_ego_info_text``.

        Args:
            node_alias: key into ``self.low_level_policies``
        """
        self.current_node = self.low_level_policies[node_alias]
        self.curr_node_alias = node_alias
        self.current_node.reset_maneuver()
        self.env.set_ego_info_text(node_alias)

    def change_low_level_references(self, env_copy):
        """Rebind this controller and every low level policy to ``env_copy``.

        No copy is made here; the caller supplies the env instance (e.g. one
        produced by ``env.copy()``) that all policies should act on from now on.

        Args:
            env_copy: env instance to bind
        """
        self.env = env_copy
        for policy in self.low_level_policies.values():
            policy.env = env_copy

    def can_transition(self):
        """Return True while the current env is not terminal."""
        return not self.env.is_terminal()

    def do_transition(self):
        """Run online MCTS from the current state and switch to the chosen option.

        The real env is never stepped by the search: traversals run on copies,
        and the original env is rebound to the controller and policies
        afterwards, even if the search raises.

        Raises:
            Exception: if no predictor has been set via set_controller_args().
        """
        # Require a predictor function
        if self.predictor is None:
            raise Exception(self.__class__.__name__ +
                            ": predictor is not set. Use set_controller_args().")
        # Remember the real env so it can be restored after the search.
        orig_env = self.env
        # Reseed numpy's global RNG from fresh entropy (no seed argument given).
        np.random.seed()
        # Snapshot the env; each traversal starts from a copy of this snapshot.
        env_before_mcts = orig_env.copy()
        # Change low level references before init MCTSLearner instance
        self.change_low_level_references(env_before_mcts)
        try:
            print('Current Node: %s' % self.curr_node_alias)
            mcts = MCTSLearner(self.env, self.low_level_policies, self.curr_node_alias)
            mcts.max_depth = self.max_depth
            mcts.set_controller_args(predictor=self.predictor)
            for _ in range(self.nb_traversals):
                # Start each traversal at the root node on a fresh env copy.
                mcts.curr_node_num = 0
                env_begin_epoch = env_before_mcts.copy()
                self.change_low_level_references(env_begin_epoch)
                init_obs = self.env.get_features_tuple()
                _, _ = mcts.traverse(init_obs)
        finally:
            # Always hand the real env back, so policies never stay bound to a
            # throwaway copy if MCTS construction or a traversal fails.
            self.change_low_level_references(orig_env)
        # Select the next option from the root node, with use_ucb disabled.
        mcts.curr_node_num = 0
        features = self.env.get_features_tuple()
        print('%s' % mcts._to_discrete(features))
        node_after_transition = mcts.get_best_node(features, use_ucb=False)
        print('MCTS suggested next option: %s' % node_after_transition)
        self.set_current_node(node_after_transition)