mcts_controller.py 4.4 KB
Newer Older
Ashish Gaurav's avatar
Ashish Gaurav committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
from .controller_base import ControllerBase
from .mcts_learner import MCTSLearner
import tqdm
import numpy as np

class MCTSController(ControllerBase):
    """MCTS Controller.

    Picks the next low-level option (node) by running Monte-Carlo tree search
    on copies of the environment, then switches to the option the search
    recommends.
    """

    def __init__(self, env, low_level_policies, start_node_alias):
        """Constructor for the MCTS-based option controller.

        Args:
            env: env instance
            low_level_policies: low level policies dictionary, keyed by node
                alias
            start_node_alias: alias of the node the controller starts in
        """
        super(MCTSController, self).__init__(env, low_level_policies,
                                             start_node_alias)
        self.curr_node_alias = start_node_alias
        # Defaults for the arguments set through set_controller_args()
        # (presumably handled by ControllerBase -- confirm).
        self.controller_args_defaults = {
            "predictor": None,
            "max_depth": 10,  # MCTS depth
            "nb_traversals": 100,  # MCTS traversals before decision
            "debug": False,
        }

    def set_current_node(self, node_alias):
        """Make ``node_alias`` the active low-level policy.

        Resets that policy's maneuver state and passes the alias to
        ``env.set_ego_info_text``.

        Args:
            node_alias: key into ``self.low_level_policies``
        """
        self.current_node = self.low_level_policies[node_alias]
        self.curr_node_alias = node_alias
        self.current_node.reset_maneuver()
        self.env.set_ego_info_text(node_alias)

    def change_low_level_references(self, env_copy):
        """Point this controller and every low-level policy at ``env_copy``.

        Args:
            env_copy: reference to copy of the environment
        """
        self.env = env_copy
        for policy in self.low_level_policies.values():
            policy.env = env_copy

    def check_env(self, x):
        """Prints the object id of the environment. Debugging function."""
        print('%s: self.env is %s' % (x, str(id(self.env))))

    def can_transition(self):
        """Return True while the environment is not in a terminal state."""
        return not self.env.is_terminal()

    def do_transition(self):
        """Run MCTS from the current state and switch to the suggested option.

        The search runs on copies of the current environment. The MCTS tree
        is created on the first call and reused on later calls. Searching
        stops once the root has ``nb_traversals`` visits, or after
        ``nb_traversals`` searches in this call. The best root option is
        then chosen, passed to ``tree.reconstruct`` (presumably re-rooting
        the tree there), and made the current node.

        Raises:
            Exception: if no predictor was provided via set_controller_args().
        """
        # Require a predictor function
        if self.predictor is None:
            raise Exception(self.__class__.__name__ +
                            ": predictor is not set. Use set_controller_args().")
        orig_env = self.env
        # No argument => NumPy re-seeds its global RNG from fresh entropy.
        np.random.seed()
        env_before_mcts = orig_env.copy()
        try:
            # Policies must reference the copy before MCTSLearner is built,
            # because the learner is constructed from self.env.
            self.change_low_level_references(env_before_mcts)
            self._ensure_mcts()
            if self.debug:
                print('')
            self._run_traversals(env_before_mcts)
        finally:
            # Always hand the real environment back to the policies, even if
            # tree construction or a search raised.
            self.change_low_level_references(orig_env)
        node_after_transition = self.mcts.best_action(self.mcts.tree.root, 0)
        if self.debug:
            print('MCTS suggested next option: %s' % node_after_transition)
            print('Btw: %s' % str(self._root_option_values()))
        self.mcts.tree.reconstruct(node_after_transition)
        self.set_current_node(node_after_transition)

    def _ensure_mcts(self):
        """Create the MCTSLearner on first use; later calls keep the tree."""
        if hasattr(self, 'mcts'):
            return
        if self.debug:
            print('Creating MCTS Tree: max depth {}'.format(self.max_depth))
        self.mcts = MCTSLearner(self.env, self.low_level_policies,
                                max_depth=self.max_depth, debug=self.debug)
        self.mcts.set_controller_args(predictor=self.predictor)

    def _run_traversals(self, env_before_mcts):
        """Run searches, each from a fresh copy of ``env_before_mcts``.

        Stops when the root has ``nb_traversals`` visits, or after
        ``nb_traversals`` searches, whichever comes first.

        Args:
            env_before_mcts: environment snapshot every search restarts from
        """
        progress = None
        if not self.debug:
            # A reused tree may already exceed the target; never go negative.
            remaining = max(0, self.nb_traversals - self.mcts.tree.root.N)
            progress = tqdm.tqdm(total=remaining)
        try:
            num_epoch = 0
            while num_epoch < self.nb_traversals:
                if self.mcts.tree.root.N >= self.nb_traversals:
                    break
                # Each search must start from the same snapshot, so copy it.
                env_begin_epoch = env_before_mcts.copy()
                self.change_low_level_references(env_begin_epoch)
                init_obs = self.env.get_features_tuple()
                self.mcts.env = env_begin_epoch
                if self.debug:
                    print('Search %d: ' % num_epoch, end=' ')
                self.mcts.search(init_obs)
                num_epoch += 1
                if progress is not None:
                    progress.update(1)
        finally:
            if progress is not None:
                progress.close()

    def _root_option_values(self):
        """Return Q/N for the root ('overall') and for each root edge.

        Used only for debug output. Nodes with N == 0 map to None instead of
        raising ZeroDivisionError.
        """
        tree = self.mcts.tree

        def q_over_n(node):
            return node.Q * 1.0 / node.N if node.N else None

        values = {'overall': q_over_n(tree.root)}
        for edge, child_key in tree.root.edges.items():
            values[edge] = q_over_n(tree.nodes[child_key])
        return values