# mcts.py: offline/online MCTS training and evaluation entry points.
from env.simple_intersection import SimpleIntersectionEnv
from options.options_loader import OptionsGraph
from backends import DDPGLearner, DQNLearner, MCTSLearner
import numpy as np
import tqdm
import argparse

import sys


class Logger(object):
    """Tee-style stdout replacement that mirrors every write to a log file.

    Assign an instance to ``sys.stdout`` so that everything printed goes
    both to the real terminal and to ``logfile.log`` (opened in append
    mode, so repeated runs accumulate in one file).
    """

    def __init__(self):
        # Keep a handle on the real stdout so output still reaches the terminal.
        self.terminal = sys.stdout
        self.log = open("logfile.log", "a")

    def write(self, message):
        """Write message to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush the file eagerly so the log survives a crash mid-run.
        self.log.flush()

    def flush(self):
        """File-like flush, needed for Python 3 compatibility.

        The log file is already flushed on every write; delegate to the
        terminal so explicit flush requests (e.g. print(..., flush=True))
        are not silently dropped.
        """
        self.terminal.flush()


# Redirect stdout through Logger so all printed output is also appended to logfile.log.
sys.stdout = Logger()

# TODO: make a separate file for this function.
def mcts_training(nb_traversals,
                  save_every=20,
                  visualize=False,
                  load_saved=False,
                  save_file="mcts.pickle"):
    """Train the offline MCTS high-level policy and periodically save it.

    Runs nb_traversals MCTS traversals of the options graph in chunks of
    save_every, saving the controller model after each chunk and printing
    the chunk's success rate and mean reward.

    Args:
     nb_traversals: number of MCTS traversals
     save_every: save at every these many traversals
     visualize: visualization / rendering
     load_saved: if True, load a previously saved model from save_file first
     save_file: pickle file used to save/load the MCTS controller model
    """

    # initialize the numpy random number generator
    np.random.seed()

    # load options graph
    options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv)
    options.load_trained_low_level_policies()

    agent = DQNLearner(
        input_shape=(50, ),
        nb_actions=options.get_number_of_nodes(),
        low_level_policies=options.maneuvers)

    # The pretrained high-level DQN provides the soft-Q predictor that
    # guides the MCTS controller.
    agent.load_model(
        "backends/trained_policies/highlevel/highlevel_weights.h5f")
    options.set_controller_args(
        predictor=agent.get_softq_value_using_option_alias)

    options.controller.max_depth = 20

    if load_saved:
        options.controller.load_model(save_file)

    total_epochs = nb_traversals // save_every
    trav_num = 1
    print('Total number of epochs = %d' % total_epochs)
    for num_epoch in range(total_epochs):
        last_rewards = []
        beg_trav_num = trav_num
        for num_traversal in tqdm.tqdm(range(save_every)):
            options.controller.curr_node_num = 0
            init_obs = options.reset()
            v, all_ep_R = options.controller.traverse(
                init_obs, visualize=visualize)

            last_rewards.append(all_ep_R)
            trav_num += 1
        options.controller.save_model(save_file)

        # A traversal counts as a success when its total reward exceeds 50.
        success_rate = np.mean([r > 50 for r in last_rewards])
        print('success rate: %f' % success_rate)
        print('Average Reward (%d-%d): %f\n' % (beg_trav_num, trav_num - 1,
                                                np.mean(last_rewards)))

def mcts_evaluation(nb_traversals,
                    num_trials=5,
                    visualize=False,
                    save_file="mcts.pickle",
                    pretrained=False):
    """Evaluate a trained offline MCTS policy over several independent trials.

    Each trial reloads the saved controller model and performs
    nb_traversals traversals, counting how many exceed the success-reward
    threshold; the mean and std of per-trial success counts are printed.

    Args:
     nb_traversals: number of MCTS traversals per trial
     num_trials: number of independent evaluation trials
     visualize: visualization / rendering
     save_file: pickle file holding the trained MCTS controller model
     pretrained: if True, load save_file from backends/trained_policies/mcts/
    """

    # initialize the numpy random number generator
    np.random.seed()

    # load options graph
    options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv)
    options.load_trained_low_level_policies()

    agent = DQNLearner(
        input_shape=(50, ),
        nb_actions=options.get_number_of_nodes(),
        low_level_policies=options.maneuvers)

    agent.load_model(
        "backends/trained_policies/highlevel/highlevel_weights.h5f")
    options.set_controller_args(
        predictor=agent.get_softq_value_using_option_alias)
    options.controller.max_depth = 20

    if pretrained:
        save_file = "backends/trained_policies/mcts/" + save_file

    success_list = []
    print('Total number of trials = %d' % num_trials)
    for trial in range(num_trials):
        num_successes = 0
        # Reload the model each trial so every trial starts from the same
        # saved controller state.
        options.controller.load_model(save_file)
        for num_traversal in tqdm.tqdm(range(nb_traversals)):
            options.controller.curr_node_num = 0
            init_obs = options.reset()
            v, all_ep_R = options.controller.traverse(
                init_obs, visualize=visualize)
            # A traversal with total reward above 50 counts as a success.
            if all_ep_R > 50:
                num_successes += 1
        print("\nTrial {}: success: {}".format(trial + 1, num_successes))
        success_list.append(num_successes)

    print("\nSuccess: Avg: {}, Std: {}".format(
        np.mean(success_list), np.std(success_list)))


def online_mcts(nb_episodes=10):
    """Run online MCTS for nb_episodes episodes and report the success count.

    MCTS visualization is off; low-level maneuver steps are visualized.

    Args:
     nb_episodes: number of episodes to run
    """

    # initialize the numpy random number generator
    np.random.seed()

    # load options graph
    options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv)
    options.load_trained_low_level_policies()

    agent = DQNLearner(
        input_shape=(50, ),
        nb_actions=options.get_number_of_nodes(),
        low_level_policies=options.maneuvers)

    agent.load_model(
        "backends/trained_policies/highlevel/highlevel_weights_772.h5f")
    options.set_controller_args(
        predictor=agent.get_softq_value_using_option_alias)

    # Loop
    num_successes = 0
    for num_ep in range(nb_episodes):
        init_obs = options.reset()
        episode_reward = 0
        first_time = True
        while not options.env.is_terminal():
            # NOTE(review): the first iteration skips stepping — presumably a
            # node must first be selected via the transition below; confirm
            # against the controller implementation.
            if first_time:
                first_time = False
            else:
                print('Stepping through ...')
                features, R, terminal, info = options.controller.\
                    step_current_node(visualize_low_level_steps=True)
                episode_reward += R
                print('Intermediate Reward: %f (ego x = %f)' %
                      (R, options.env.vehs[0].x))
                print('')
            if options.controller.can_transition():
                options.controller.do_transition()
        print('')
        print('EPISODE %d: Reward = %f' % (num_ep, episode_reward))
        print('')
        print('')
        # An episode with total reward above 50 counts as a success.
        if episode_reward > 50:
            num_successes += 1

    print("Policy succeeded {} times!".format(num_successes))


def evaluate_online_mcts(nb_episodes=20, nb_trials=5):
    """Evaluate the online MCTS policy over nb_trials trials of nb_episodes.

    Prints per-trial success counts, then the mean/std of successes and a
    breakdown of episode termination reasons. MCTS visualization is off.

    Args:
     nb_episodes: number of episodes per trial
     nb_trials: number of independent evaluation trials
    """

    # initialize the numpy random number generator
    np.random.seed()

    # load options graph
    options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv)
    options.load_trained_low_level_policies()

    agent = DQNLearner(
        input_shape=(50, ),
        nb_actions=options.get_number_of_nodes(),
        low_level_policies=options.maneuvers)

    agent.load_model(
        "backends/trained_policies/highlevel/highlevel_weights_772.h5f")
    options.set_controller_args(
        predictor=agent.get_softq_value_using_option_alias)

    print("\nConducting {} trials of {} episodes each".format(
        nb_trials, nb_episodes))
    success_list = []
    termination_reason_list = {}
    for trial in range(nb_trials):
        # Loop
        num_successes = 0
        termination_reason_counter = {}
        for num_ep in range(nb_episodes):
            init_obs = options.reset()
            episode_reward = 0
            first_time = True
            while not options.env.is_terminal():
                # The first iteration only transitions (selects a node);
                # stepping starts from the second iteration onwards.
                if first_time:
                    first_time = False
                else:
                    print('Stepping through ...')
                    features, R, terminal, info = options.controller.\
                        step_current_node(visualize_low_level_steps=True)
                    episode_reward += R
                    print('Intermediate Reward: %f (ego x = %f)' %
                          (R, options.env.vehs[0].x))
                    print('')
                    if terminal:
                        if 'episode_termination_reason' in info:
                            termination_reason = info[
                                'episode_termination_reason']
                            # Tally termination reasons within this trial.
                            termination_reason_counter[termination_reason] = \
                                termination_reason_counter.get(
                                    termination_reason, 0) + 1
                if options.controller.can_transition():
                    options.controller.do_transition()
            print('')
            print('EPISODE %d: Reward = %f' % (num_ep, episode_reward))
            print('')
            print('')
            # An episode with total reward above 50 counts as a success.
            if episode_reward > 50:
                num_successes += 1

        print("\nTrial {}: success: {}".format(trial + 1, num_successes))
        success_list.append(num_successes)
        # Accumulate per-trial reason counts across trials.
        for reason, count in termination_reason_counter.items():
            termination_reason_list.setdefault(reason, []).append(count)

    success_list = np.array(success_list)
    print("\nSuccess: Avg: {}, Std: {}".format(
        np.mean(success_list), np.std(success_list)))
    print("Termination reason(s):")
    for reason, count_list in termination_reason_list.items():
        count_list = np.array(count_list)
        print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list),
                                            np.std(count_list)))


if __name__ == "__main__":
    # Command-line interface: train and/or evaluate the offline MCTS policy.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train",
        help=
        "Train an offline mcts with default settings. Always saved in root folder.",
        action="store_true")
    parser.add_argument(
        "--evaluate",
        help="Evaluate over n trials. "
        "Uses backends/trained_policies/mcts/mcts.pickle by default",
        action="store_true")
    parser.add_argument(
        "--saved_policy_in_root",
        help=
        "Use saved policies in root of project rather than backends/trained_policies/mcts/",
        action="store_true")
    parser.add_argument(
        "--load_saved",
        help="Load a saved policy from root folder first before training",
        action="store_true")
    parser.add_argument(
        "--visualize",
        help=
        "Visualize the training. Testing is always visualized. Evaluation is not visualized by default",
        action="store_true")
    parser.add_argument(
        "--nb_traversals",
        help="Number of traversals to perform. Default is 1000",
        default=1000,
        type=int)
    parser.add_argument(
        "--save_every",
        help=
        "Saves every n traversals. Saves in root by default. Default is 500",
        default=500,
        type=int)
    parser.add_argument(
        "--nb_traversals_for_test",
        help="Number of episodes to evaluate. Default is 100",
        default=100,
        type=int)
    parser.add_argument(
        "--nb_trials",
        help="Number of trials to evaluate. Default is 10",
        default=10,
        type=int)
    parser.add_argument(
        "--save_file",
        help=
        "filename to save/load the trained policy. Location is as specified by --saved_policy_in_root. Default name is mcts.pickle",
        default="mcts.pickle")

    args = parser.parse_args()

    # --train and --evaluate are independent; both may run in one invocation.
    if args.train:
        mcts_training(
            nb_traversals=args.nb_traversals,
            save_every=args.save_every,
            visualize=args.visualize,
            load_saved=args.load_saved,
            save_file=args.save_file)
    if args.evaluate:
        mcts_evaluation(
            nb_traversals=args.nb_traversals_for_test,
            num_trials=args.nb_trials,
            visualize=args.visualize,
            pretrained=not args.saved_policy_in_root,
            save_file=args.save_file)