mcts_learner.py
from .controller_base import ControllerBase
import numpy as np
import pickle


class MCTSLearner(ControllerBase):
    """Monte Carlo Tree Search implementation using the UCB1 and progressive
    widening approach as explained in Paxton et al (2017)."""
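
    # Selection score (see _ucb_adjusted_q below): an exploitation term
    #   Q_star(s, o) = TR(s, o) / (1 + M(s, o)),
    # shifted and scaled into [0, 1], plus the exploration term
    #   C * predictor(s, o) / (1 + M(s, o)),
    # where M counts visits of each (discrete observation, option) pair.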

    _ucb_vals = set()

    def __init__(self, env, low_level_policies, start_node_alias,
                 max_depth=10):
        """Constructor for MCTSLearner.

        Args:
            env: env instance
            low_level_policies: low level policies dictionary
            start_node: starting node
            predictor: P(s, o) learner class; forward pass should
                return the entire value from state s and option o
            max_depth: max depth of the MCTS tree; default 10 levels
        """

        super(MCTSLearner, self).__init__(env, low_level_policies,
                                          start_node_alias)

        self.controller_args_defaults = {
            "predictor":
            None  #P(s, o) learner class; forward pass should return the entire value from state s and option o
        }
        self.max_depth = max_depth
        #: store current node alias
        self.curr_node_alias = start_node_alias
        #: store current node's id
        self.curr_node_num = 0
        #: when a new node is created, it has this id;
        #  afterwards this number is updated
        self.new_node_num = 0
        #: visitation count of discrete observations
        self.N = {}
        #: visitation count of discrete observation with option
        self.M = {}
        #: total reward from given discrete observation with option
        self.TR = {}
        #: node properties
        self.nodes = {}
        #: adjacency list
        self.adj = {}
        # populate root node
        root_node_num, root_node_info = self._create_node(self.curr_node_alias)
        self.nodes[root_node_num] = root_node_info
        self.adj[root_node_num] = set()  # no children

    def save_model(self, file_name="mcts.pickle"):
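        """Pickle the tree statistics (N, M, TR, nodes, adj, new_node_num) to
        file_name."""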
        to_backup = {
            'N': self.N,
            'M': self.M,
            'TR': self.TR,
            'nodes': self.nodes,
            'adj': self.adj,
            'new_node_num': self.new_node_num
        }
        with open(file_name, 'wb') as handle:
            pickle.dump(to_backup, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def load_model(self, file_name='mcts.pickle'):
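        """Restore tree statistics previously saved with save_model."""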
        with open(file_name, 'rb') as handle:
            to_restore = pickle.load(handle)
        self.N = to_restore['N']
        self.M = to_restore['M']
        self.TR = to_restore['TR']
        self.nodes = to_restore['nodes']
        self.adj = to_restore['adj']
        self.new_node_num = to_restore['new_node_num']

    def _create_node(self, low_level_policy):
        """Create the node associated with curr_node_num, using the given low
        level policy.

        Args:
            low_level_policy: the option's alias

        Returns the new node id and a dict containing the properties of
            the created node.
        """

        # print('Creating new node %d:%s' % (self.new_node_num, low_level_policy))
        created_node_num = self.new_node_num
        self.new_node_num += 1
        return created_node_num, {"policy": low_level_policy}

    def _to_discrete(self, observation):
        """Converts observation to a discrete observation tuple. Also append
        (a) whether we are following a vehicle, and (b) whether there is a
        vehicle in the opposite lane in the approximately the same x position.
        These values will be useful for Follow and ChangeLane maneuvers.

        Args:
            observation: observation tuple from the environment

        Returns the discrete observation
        """
        dis_observation = ''
        for item in observation[12:20]:
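            # bool is tested before int: bool is a subclass of int in Python,
            # so an isinstance(item, int) check alone would also match
            # True/False.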
            if isinstance(item, bool):
                dis_observation += '1' if item else '0'
            elif isinstance(item, int) and item in (0, 1):
                dis_observation += str(item)

        env = self.current_node.env

        # Are we following a vehicle?
        target_veh_i, V2V_dist = env.get_V2V_distance()
        if target_veh_i is not None and V2V_dist <= 30:
            dis_observation += '1'
        else:
            dis_observation += '0'

        # Is there a vehicle in the opposite lane in approximately the
        # same x position?
        delta_x = 15
        ego = env.vehs[0]
        possible_collision = '0'
        for veh in env.vehs[1:]:
            abs_delta_x = abs(ego.x - veh.x)
            opp_lane = ego.APs['lane'] != veh.APs['lane']
            is_horizontal = veh.APs['on_route']
            if abs_delta_x <= delta_x and opp_lane and is_horizontal:
                possible_collision = '1'
                break
        dis_observation += possible_collision

        return dis_observation

    def _get_visitation_count(self, observation, option=None):
        """Finds the visitation count of the discrete form of the observation.
        If the discrete observation is not found, it is inserted into self.N
        with value 0. The observation is automatically converted into its
        discrete form. If option is not None, self.M is used instead of
        self.N.

        Args:
            observation: observation tuple from the environment
            option: option alias

        Returns N(discrete observation) or M(discrete observation, option)
            depending on whether option is given
        """

        dis_observation = self._to_discrete(observation)
        if option is None:
            if dis_observation not in self.N:
                self.N[dis_observation] = 0
            return self.N[dis_observation]
        else:
            if (dis_observation, option) not in self.M:
                self.M[(dis_observation, option)] = 0
            return self.M[(dis_observation, option)]

    def _get_q_star(self, observation, option):
        """Compute average value of discrete observation - option pair.

        Args:
            observation: observation tuple from the environment
            option: option alias

        Returns Q_star(discrete form of the observation, option)
        """

        dis_observation = self._to_discrete(observation)
        if (dis_observation, option) not in self.TR:
            self.TR[(dis_observation, option)] = 0
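        # Q_star(s, o) = TR(s, o) / (1 + M(s, o)): total reward averaged over
        # a smoothed visitation count (the +1 guards against division by zero).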
        return self.TR[(dis_observation, option)] / (
            1 + self._get_visitation_count(observation, option))

    def _get_possible_options(self):
        """Returns a set of options that can be taken from the current node.

        Checks each low level policy's initiation condition and returns
        the aliases of the options whose condition is currently met.
        """

        all_options = set(self.low_level_policies.keys())

        # Keep options whose initiation condition is true
        filtered_options = set()
        for option_alias in all_options:
            self.low_level_policies[option_alias].reset_maneuver()
            if self.low_level_policies[option_alias].initiation_condition:
                filtered_options.add(option_alias)

        return filtered_options

    def _get_already_visited_options(self, node_num):
        """Returns a set of options that have already been tried at least once
        from the given node.

        Args:
            node_num: number of the node from which the returned set is
                computed

        Returns a set of aliases of already visited options from node_num
        """

        visited_aliases = set()
        for already_existing_node_num in self.adj[node_num]:
            already_existing_node = self.nodes[already_existing_node_num]
            node_alias = already_existing_node["policy"]
            visited_aliases.add(node_alias)
        return visited_aliases

    def _ucb_adjusted_q(self, observation, C=1):
        """Computes Q_star(observation, option_i) plus the UCB term, which is
        C*[predictor(observation, option_i)]/[1+N(observation, option_i)], for
        all option_i in the adjacency set of the current node.

        Args:
            observation: observation tuple from the environment
            C: UCB constant, experimentally defined

        Returns Q values for next nodes
        """

        Q = {}
        Q1, Q2 = {}, {}  # debug
        dis_observation = self._to_discrete(observation)
        next_option_nums = self.adj[self.curr_node_num]
        for next_option_num in next_option_nums:
            next_option = self.nodes[next_option_num]["policy"]
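            # Exploitation term, shifted and scaled into roughly [0, 1]; the
            # +200 / 400 normalization suggests returns are assumed to lie in
            # about [-200, 200] (an inference from the constants, not stated
            # in the original).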
            Q1[next_option] = (
                self._get_q_star(observation, next_option) + 200) / 400
            Q[(dis_observation, next_option)] = \
                Q1[next_option]
            Q2[next_option] = C * \
                (self.predictor(observation, next_option)) / \
                (1 + self._get_visitation_count(observation, next_option))
            self._ucb_vals.add(Q2[next_option])
            Q[(dis_observation, next_option)] += \
                Q2[next_option]
        return Q, Q1, Q2

    def _value(self, observation):
        """Computes the value function v(s) by computing sum of all
        TR[dis_observation, option_i] divided by N[dis_observation]

        Args:
            observation: observation tuple from the environment

        Returns v(s)
        """

        dis_observation = self._to_discrete(observation)
        relevant_rewards = [value for key, value in self.TR.items() \
            if key[0] == dis_observation]
        sum_rewards = np.sum(relevant_rewards)
        return sum_rewards / (1 + self._get_visitation_count(observation))

    def _select(self, observation, depth=0, visualize=False):
        """MCTS selection function. For representation, we only use the
        discrete part of the observation.

        Args:
            observation: observation tuple from the environment
            depth: current depth, starts from root node, hence 0 by default
            visualize: whether or not to visualize low level steps

        Returns the value of the given observation and the accumulated
        episode reward.
        """

        # First compute whether the observation is terminal or not
        env = self.current_node.env
        is_terminal = env.is_terminal()
        max_depth_reached = depth >= self.max_depth
        dis_observation = self._to_discrete(observation)
        # print('Depth %d:\t %s' % (depth, dis_observation))

        if is_terminal or max_depth_reached:
            # print('MCTS went %d nodes deep' % depth)
            return self._value(
                observation), 0  # TODO: replace with final goal reward

        Ns = self._get_visitation_count(observation)
        Nchildren = len(self.adj[self.curr_node_num])

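        # Progressive widening: add a new child only while the number of
        # children stays below the square root of this node's visitation
        # count, so the branching factor grows with experience.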
        if Ns == 0 or Nchildren < np.sqrt(Ns):
            # Add new edge
            new_options = self._get_possible_options()
            # Randomly choose option
            new_option = np.random.choice(list(new_options))
            # Does the chosen option already exist?
            already_visited_options = self._get_already_visited_options(
                self.curr_node_num)
            already_chosen_once = new_option in already_visited_options
            # If not visited, create it
            if not already_chosen_once:
                new_node_num, new_node_info = self._create_node(new_option)
                self.nodes[new_node_num] = new_node_info
                self.adj[new_node_num] = set()
                self.adj[self.curr_node_num].add(new_node_num)

        # Find o_star and do a transition, i.e. update curr_node
        # Simulate / lookup; do_transition switches the current node first
        next_observation, episode_R, o_star = self.do_transition(
            observation, visualize=visualize)

        # Recursively select next node
        remaining_v, all_ep_R = self._select(
            next_observation, depth + 1, visualize=visualize)

        # Update values
        self.N[dis_observation] += 1
        self.M[(dis_observation, o_star)] += 1
        self.TR[(dis_observation, o_star)] += (episode_R + remaining_v)

        return self._value(observation), all_ep_R + episode_R

    def traverse(self, observation, visualize=False):
        """Do a complete traversal from root to leaf. Assumes the environment
        is reset and we are at the root node.

        Args:
            observation: observation from the environment
            visualize: whether or not to visualize low level steps

        Returns the value of the root node and the total episode reward
        """

        return self._select(observation, visualize=visualize)

    def do_transition(self, observation, visualize=False):
        """Do a transition using UCB metric, with the latest observation from
        the episodic step.

        Args:
            observation: final observation from episodic step
            visualize: whether or not to visualize low level steps

        Returns the next observation, the episode reward, and o_star
        chosen using the UCB metric
        """

        # Choose an option to expand using UCB metric
        Q, Q1, Q2 = self._ucb_adjusted_q(observation, C=1)
        # Compute o_star = argmax_o Q(s, o)
        next_keys, next_values = list(Q.keys()), list(Q.values())
        _, o_star = next_keys[np.argmax(next_values)]
        # print('Current state: %s' % self._to_discrete(observation))
        # print('Choosing %s: %f, %f, sum:%f' % (o_star, Q1[o_star], Q2[o_star], np.max(next_values)))
        # Change current_option to reflect this o_star
        self.set_current_node(o_star)
        next_obs, eps_R, _, _ = self.step_current_node(
            visualize_low_level_steps=visualize)
        return next_obs, eps_R, o_star

    def set_current_node(self, new_node_alias):
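        """Switch the current node to the child of the current node whose
        policy alias matches new_node_alias, and reset that maneuver."""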
        next_option_nums = self.adj[self.curr_node_num]
        found_next_node = False
        for next_option_num in next_option_nums:
            next_option = self.nodes[next_option_num]["policy"]
            if next_option == new_node_alias:
                # print('Currently at %d:%s, switching to %d:%s' % (self.curr_node_num, self.curr_node_alias, next_option_num, new_node_alias))
                self.curr_node_num = next_option_num
                found_next_node = True
                break
        if self.curr_node_num != 0:
            assert found_next_node, "Couldn't transition to next node (%s) "\
                "from current_node (%s)" % (new_node_alias, self.curr_node_alias)
        self.curr_node_alias = new_node_alias
        self.current_node = self.low_level_policies[self.curr_node_alias]
        self.current_node.reset_maneuver()
        self.env.set_ego_info_text(self.curr_node_alias)

    def get_best_node(self, observation, use_ucb=False):
        """Return the alias of the best next node from the current node,
        using UCB-adjusted Q values if use_ucb is True."""
        Q, Q1, Q2 = self._ucb_adjusted_q(observation, C=1)
        if use_ucb:
            next_keys, next_values = list(Q.keys()), list(Q.values())
            _, o_star = next_keys[np.argmax(next_values)]
            print(Q)
        else:
            next_keys, next_values = list(Q1.keys()), list(Q1.values())
            o_star = next_keys[np.argmax(next_values)]
            print(Q1)
        return o_star
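

# A minimal usage sketch (not part of the original file). All names here are
# hypothetical placeholders: `env`, `options`, `my_predictor`, the alias
# 'keep_lane', and `env.get_features_tuple()` are assumptions, and the
# controller-args setter is inferred from controller_args_defaults above.
#
#     mcts = MCTSLearner(env, options, start_node_alias='keep_lane')
#     mcts.set_controller_args(predictor=my_predictor)  # P(s, o) estimator
#     for _ in range(1000):          # number of tree traversals
#         env.reset()
#         mcts.curr_node_num = 0     # restart selection at the root node
#         mcts.traverse(env.get_features_tuple())
#     mcts.save_model('mcts.pickle')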