online_mcts_controller.py 3.08 KB
Newer Older
Aravind Bk's avatar
Aravind Bk committed
1 2 3 4 5
from .controller_base import ControllerBase
from .mcts_learner import MCTSLearner
import tqdm
import numpy as np

Ashish Gaurav's avatar
Ashish Gaurav committed
6

Aravind Bk's avatar
Aravind Bk committed
7
class OnlineMCTSController(ControllerBase):
    """Online MCTS controller.

    On every transition it runs a fresh Monte-Carlo tree search starting from
    copies of the current environment, then switches to the option (low-level
    policy) that the search ranks best at the root.
    """

    def __init__(self, env, low_level_policies, start_node_alias):
        """Constructor for the online MCTS controller.

        Args:
            env: env instance
            low_level_policies: low level policies dictionary, keyed by node
                alias
            start_node_alias: alias of the node (option) to start from
        """
        super(OnlineMCTSController, self).__init__(env, low_level_policies,
                                                   start_node_alias)
        self.curr_node_alias = start_node_alias
        # Defaults for the arguments accepted by set_controller_args().
        # NOTE(review): presumably ControllerBase exposes these as attributes
        # (self.predictor, self.max_depth, self.nb_traversals) — confirm in
        # controller_base.
        self.controller_args_defaults = {
            "predictor": None,
            "max_depth": 5,  # MCTS depth
            "nb_traversals": 30,  # MCTS traversals before decision
        }

    def set_current_node(self, node_alias):
        """Switch to the low-level policy registered under ``node_alias``.

        The policy's maneuver state is reset and the alias is passed to
        ``env.set_ego_info_text``.
        """
        self.current_node = self.low_level_policies[node_alias]
        self.curr_node_alias = node_alias
        self.current_node.reset_maneuver()
        self.env.set_ego_info_text(node_alias)

    def change_low_level_references(self, env_copy):
        """Point this controller and every low-level policy at ``env_copy``.

        No copy is made here. Callers pass an environment they have already
        copied, or the original environment to restore it.
        """
        self.env = env_copy
        for policy in self.low_level_policies.values():
            policy.env = env_copy

    def can_transition(self):
        """Return True while the current environment is not terminal."""
        return not self.env.is_terminal()

    def do_transition(self):
        """Run online MCTS from the current state and switch to the best option.

        Performs ``nb_traversals`` traversals (depth ``max_depth``). Each one
        starts from a fresh copy of the environment snapshot taken on entry.
        The best node is then chosen with ``use_ucb=False``. The controller
        and the low-level policies are always re-pointed at the original
        environment afterwards, even if the search raises.

        Raises:
            Exception: if no predictor has been set via set_controller_args().
        """
        # Require a predictor function
        if self.predictor is None:
            raise Exception(self.__class__.__name__ +
                            ": predictor is not set. Use set_controller_args().")
        # Store the env at this point
        orig_env = self.env
        # Reseed numpy's global RNG from OS entropy (presumably so that
        # successive searches are not identical).
        np.random.seed()
        # Snapshot the current state; every traversal starts from a copy of it.
        env_before_mcts = orig_env.copy()
        try:
            # MCTSLearner is constructed with self.env, so switch the
            # references to the snapshot first.
            self.change_low_level_references(env_before_mcts)
            print('Current Node: %s' % self.curr_node_alias)
            mcts = MCTSLearner(self.env, self.low_level_policies,
                               self.curr_node_alias)
            mcts.max_depth = self.max_depth
            mcts.set_controller_args(predictor=self.predictor)
            # NOTE(review): mcts keeps a reference to env_before_mcts. If
            # MCTSLearner.traverse steps its own env rather than the policies'
            # env, later traversals would start from a mutated snapshot.
            # Confirm in mcts_learner.
            for _ in range(self.nb_traversals):
                mcts.curr_node_num = 0
                env_begin_epoch = env_before_mcts.copy()
                self.change_low_level_references(env_begin_epoch)
                init_obs = self.env.get_features_tuple()
                mcts.traverse(init_obs)
        finally:
            # Never leave the controller or the policies pointing at a
            # throwaway copy, even if the search raised.
            self.change_low_level_references(orig_env)
        # Find the nodes from the root node
        mcts.curr_node_num = 0
        print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
        node_after_transition = mcts.get_best_node(
            self.env.get_features_tuple(), use_ucb=False)
        print('MCTS suggested next option: %s' % node_after_transition)
        self.set_current_node(node_after_transition)