mcts.py 7.8 KB
Newer Older
Aravind Bk's avatar
Aravind Bk committed
1 2 3 4 5
from env.simple_intersection import SimpleIntersectionEnv
from options.options_loader import OptionsGraph
from backends import DDPGLearner, DQNLearner, MCTSLearner
import numpy as np
import tqdm
Unknown's avatar
Unknown committed
6
import argparse
Ashish Gaurav's avatar
Ashish Gaurav committed
7
import time, datetime
Aravind Bk's avatar
Aravind Bk committed
8 9
import sys

Ashish Gaurav's avatar
Ashish Gaurav committed
10

Aravind Bk's avatar
Aravind Bk committed
11 12 13 14 15 16 17
class Logger(object):
    """Tee stream: writes go to the wrapped terminal stream and to a log file.

    Intended to replace ``sys.stdout`` so that printed output is both shown and
    appended to a log file.

    Args:
     filename: path of the log file; opened in append mode.
     terminal: stream to echo to; defaults to the current ``sys.stdout``.
    """

    def __init__(self, filename="logfile.log", terminal=None):
        self.terminal = sys.stdout if terminal is None else terminal
        self.log = open(filename, "a", encoding="utf-8")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        # Flush the log on every write so its contents survive a crash or an
        # interrupted run.
        self.log.flush()

    def flush(self):
        # Called by print(..., flush=True) and when the interpreter flushes
        # sys.stdout; forward it to both underlying streams.
        self.terminal.flush()
        self.log.flush()

    def __getattr__(self, name):
        # Only reached for attributes Logger does not define itself (e.g.
        # isatty, encoding, fileno); delegate them to the terminal stream so
        # code expecting a real file object keeps working. Guard the two
        # instance attributes to avoid infinite recursion on a half-built
        # instance.
        if name in ("terminal", "log"):
            raise AttributeError(name)
        return getattr(self.terminal, name)

Aravind Bk's avatar
Aravind Bk committed
27 28 29

# Installed at import time: from here on, everything written to sys.stdout is
# echoed to the terminal and appended to logfile.log (see Logger above).
sys.stdout = Logger()

Ashish Gaurav's avatar
Ashish Gaurav committed
30

Ashish Gaurav's avatar
Ashish Gaurav committed
31 32 33 34
def mcts_evaluation(depth,
                    nb_traversals,
                    nb_episodes,
                    nb_trials,
                    visualize=False,
                    debug=False,
                    pretrained=True,
                    highlevel_policy_file="highlevel_weights.h5f"):
    """Evaluate MCTS over the options graph, guided by a trained high-level policy.

    Loads the trained low-level policies and a DQN high-level policy whose
    soft-Q values are used as the MCTS predictor, then runs ``nb_trials``
    trials of ``nb_episodes`` episodes each, printing per-episode, per-trial
    and overall reward / success statistics plus termination-reason counts.

    Args:
     depth: max depth of each tree search
     nb_traversals: number of MCTS traversals per episode
     nb_episodes: number of episodes per trial (>= 1)
     nb_trials: number of trials (>= 1)
     visualize: render the low-level steps while stepping the controller
     debug: whether or not to show debug information
     pretrained: if True, ``highlevel_policy_file`` is looked up under
         backends/trained_policies/highlevel/; otherwise it is used as given
     highlevel_policy_file: file name of the saved high-level policy weights

    Raises:
     ValueError: if nb_episodes or nb_trials is smaller than 1.
    """
    # Fail fast, before the slow model loading, on arguments that would
    # otherwise end in a division by zero / statistics over empty lists.
    if nb_episodes < 1 or nb_trials < 1:
        raise ValueError(
            "nb_episodes and nb_trials must be >= 1 (got {} and {})".format(
                nb_episodes, nb_trials))

    # initialize the numpy random number generator with a fresh,
    # non-deterministic seed
    np.random.seed()

    # load config and maneuvers
    options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv,
        randomize_special_scenarios=True)
    options.load_trained_low_level_policies()

    # load high level policy for UCT prediction
    agent = DQNLearner(
        input_shape=(50, ),
        nb_actions=options.get_number_of_nodes(),
        low_level_policies=options.maneuvers)

    if pretrained:
        highlevel_policy_file = "backends/trained_policies/highlevel/" + highlevel_policy_file
    agent.load_model(highlevel_policy_file)

    # set predictor
    options.set_controller_args(
        predictor=agent.get_softq_value_using_option_alias,
        max_depth=depth,
        nb_traversals=nb_traversals,
        debug=debug)

    # Evaluate
    print("\nConducting {} trials of {} episodes each".format(
        nb_trials, nb_episodes))
    overall_reward_list = []
    overall_success_accuracy = []
    # One {reason: count} dict per trial. Kept whole so that the overall
    # statistics can count "reason never happened in this trial" as 0.
    per_trial_termination_counts = []
    for num_tr in range(nb_trials):
        num_successes = 0
        reward_list = []
        trial_termination_reason_counter = {}
        for num_ep in range(nb_episodes):
            episode_reward, total_time = _run_mcts_episode(
                options, visualize, trial_termination_reason_counter)
            if options.env.goal_achieved:
                num_successes += 1
            print('Episode {}: Reward = {} ({})'.format(
                num_ep, episode_reward,
                datetime.timedelta(seconds=total_time)))
            reward_list.append(episode_reward)

        per_trial_termination_counts.append(trial_termination_reason_counter)

        print("\nTrial {}: Reward = (Avg: {}, Std: {}), Successes: {}/{}".format(
            num_tr, np.mean(reward_list), np.std(reward_list),
            num_successes, nb_episodes))
        print("Trial {} Termination reason(s):".format(num_tr))
        # Within one trial each reason has a single count, so there is no
        # spread to report.
        for reason, count in trial_termination_reason_counter.items():
            print("{}: {}".format(reason, count))
        print("\n")

        overall_reward_list += reward_list
        overall_success_accuracy.append(num_successes / float(nb_episodes))

    print("===========================")
    print('Overall: Reward = (Avg: {}, Std: {}), Success = (Avg: {}, Std: {})\n'.format(
        np.mean(overall_reward_list), np.std(overall_reward_list),
        np.mean(overall_success_accuracy), np.std(overall_success_accuracy)))

    print("Termination reason(s):")
    _print_termination_stats(per_trial_termination_counts)


def _run_mcts_episode(options, visualize, termination_counter):
    """Play one episode with the options-graph controller.

    Returns ``(episode_reward, elapsed_seconds)`` where elapsed_seconds is the
    truncated wall-clock time after reset. For every terminal step whose info
    carries 'episode_termination_reason', increments
    ``termination_counter[reason]`` in place.
    """
    options.reset()
    episode_reward = 0
    start_time = time.time()
    # On the first pass nothing is stepped; only a transition is attempted
    # (presumably to choose the first option from the initial node).
    if not options.env.is_terminal() and options.controller.can_transition():
        options.controller.do_transition()
    while not options.env.is_terminal():
        _, R, terminal, info = options.controller.step_current_node(
            visualize_low_level_steps=visualize)
        episode_reward += R
        if terminal and 'episode_termination_reason' in info:
            reason = info['episode_termination_reason']
            termination_counter[reason] = termination_counter.get(reason, 0) + 1
        if options.controller.can_transition():
            options.controller.do_transition()
    return episode_reward, int(time.time() - start_time)


def _print_termination_stats(per_trial_counts):
    """Print, per termination reason, the mean and std of its count over trials.

    A trial in which a reason never occurred contributes 0 for that reason, so
    every reason is averaged over all trials, not only those where it appeared.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order.
    reasons = dict.fromkeys(
        reason for counts in per_trial_counts for reason in counts)
    for reason in reasons:
        count_list = np.array(
            [counts.get(reason, 0) for counts in per_trial_counts])
        print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list),
                                            np.std(count_list)))

Aravind Bk's avatar
Aravind Bk committed
148
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate MCTS with a trained high-level policy.")
    parser.add_argument(
        "--evaluate",
        help="Evaluate over n trials, no visualization by default.",
        action="store_true")
    parser.add_argument(
        "--test",
        help="Tests MCTS for 100 episodes by default.",
        action="store_true")
    parser.add_argument(
        "--nb_episodes_for_test",
        help="Number of episodes to run with --test. "
        "Defaults to the value of --nb_episodes (100)",
        default=None,
        type=int)
    parser.add_argument(
        "--visualize",
        help="Visualize the evaluation.",
        action="store_true")
    parser.add_argument(
        "--depth",
        help="Max depth of tree per episode. Default is 5",
        default=5,
        type=int)
    parser.add_argument(
        "--nb_traversals",
        help="Number of traversals to perform per episode. Default is 50",
        default=50,
        type=int)
    parser.add_argument(
        "--nb_episodes",
        help="Number of episodes per trial to evaluate. Default is 100",
        default=100,
        type=int)
    parser.add_argument(
        "--nb_trials",
        help="Number of trials to evaluate. Default is 10",
        default=10,
        type=int)
    parser.add_argument(
        "--debug",
        help="Show debug output. Default is false",
        action="store_true")
    parser.add_argument(
        "--highlevel_policy_in_root",
        help=
        "Use saved high-level policy in root of project rather than backends/trained_policies/highlevel/",
        action="store_true")

    args = parser.parse_args()

    # Arguments shared by both run modes.
    common_args = dict(
        depth=args.depth,
        nb_traversals=args.nb_traversals,
        debug=args.debug,
        pretrained=not args.highlevel_policy_in_root)

    if args.evaluate:
        mcts_evaluation(
            nb_episodes=args.nb_episodes,
            nb_trials=args.nb_trials,
            visualize=args.visualize,
            **common_args)
    elif args.test:
        # --test is a single, always-visualized trial.
        if args.nb_episodes_for_test is None:
            nb_test_episodes = args.nb_episodes
        else:
            nb_test_episodes = args.nb_episodes_for_test
        mcts_evaluation(
            nb_episodes=nb_test_episodes,
            nb_trials=1,
            visualize=True,
            **common_args)
    else:
        # No run mode selected: show usage instead of exiting silently.
        parser.print_help()