high_level_policy_main.py 8.7 KB
Newer Older
Aravind Bk's avatar
Aravind Bk committed
1 2 3 4 5 6 7 8 9
from env.simple_intersection import SimpleIntersectionEnv
from env.simple_intersection.constants import *
from options.options_loader import OptionsGraph
from backends.kerasrl_learner import DQNLearner
import os
import argparse


# TODO: make a separate file for this function.
Ashish Gaurav's avatar
Ashish Gaurav committed
10 11 12 13
def high_level_policy_training(nb_steps=25000,
                               load_weights=False,
                               training=True,
                               testing=True,
                               nb_episodes_for_test=20,
                               max_nb_steps=100,
                               visualize=False,
                               tensorboard=False,
                               save_path="highlevel_weights.h5f"):
    """Do RL of the high-level policy and test it.

    Args:
     nb_steps: the number of steps to perform RL (positive int; booleans
        are rejected even though bool is a subclass of int)
     load_weights: True if the pre-learned NN weights are loaded from
        save_path before training (for initializations of NNs)
     training: True to enable training
     testing: True to enable testing
     nb_episodes_for_test: the number of episodes for testing
     max_nb_steps: forwarded to agent.train as nb_max_episode_steps
     visualize: when training, True sets
        options.visualize_low_level_steps before the agent is trained
     tensorboard: forwarded to agent.train
     save_path: file the weights are loaded from (if load_weights) and
        saved to (if training); the test phase evaluates this same file

    Returns:
     the DQNLearner agent built (and possibly trained) by this call.

    Raises:
     ValueError: if load_weights, training or testing is not a bool, if
        both load_weights and training are False, or if nb_steps is not
        a positive int.
    """
    # initialize the numpy random number generator
    np.random.seed()

    if not (isinstance(load_weights, bool) and isinstance(training, bool)
            and isinstance(testing, bool)):
        raise ValueError("Type error: the variable has to be boolean.")

    if not load_weights and not training:
        raise ValueError(
            "Both load_weights and training are False: no learning and no loading weights."
        )

    # bool is a subclass of int, so True/False would otherwise slip through
    # the isinstance check and be treated as 1/0 steps.
    if (isinstance(nb_steps, bool) or not isinstance(nb_steps, int)
            or nb_steps <= 0):
        raise ValueError("nb_steps has to be a positive number.")

    # `options` is the module-level options graph created by the script
    # entry point; it must exist before this function is called.
    global options

    agent = DQNLearner(
        input_shape=(50, ),
        nb_actions=options.get_number_of_nodes(),
        target_model_update=1e-3,
        delta_clip=100,
        low_level_policies=options.maneuvers)

    if load_weights:
        agent.load_model(save_path)

    if training:
        if visualize:
            options.visualize_low_level_steps = True
        agent.train(
            options,
            nb_steps=nb_steps,
            nb_max_episode_steps=max_nb_steps,
            tensorboard=tensorboard)
        agent.save_model(save_path)

    if testing:
        # Test the weights that were just loaded/saved: without passing
        # save_path the test would always read the default weights file,
        # regardless of where this call stored its result.
        high_level_policy_testing(
            nb_episodes_for_test=nb_episodes_for_test,
            trained_agent_file=save_path)

    return agent


Ashish Gaurav's avatar
Ashish Gaurav committed
71 72 73 74
def high_level_policy_testing(nb_episodes_for_test=100,
                              trained_agent_file="highlevel_weights.h5f",
                              pretrained=False,
                              visualize=True):
    """Load a saved high-level policy and run it for some test episodes.

    Args:
     nb_episodes_for_test: the number of episodes passed to
        agent.test_model as nb_episodes
     trained_agent_file: name of the weights file to load
     pretrained: True to look for trained_agent_file under
        backends/trained_policies/highlevel/ instead of using the name
        as given
     visualize: forwarded to agent.test_model
    """
    # The options graph is the module-level object set up by the script
    # entry point.
    global options

    learner_config = {
        "input_shape": (50, ),
        "nb_actions": options.get_number_of_nodes(),
        "low_level_policies": options.maneuvers,
    }
    learner = DQNLearner(**learner_config)

    weights_file = ("backends/trained_policies/highlevel/" +
                    trained_agent_file if pretrained else trained_agent_file)
    learner.load_model(weights_file)

    # Let the options graph pick its nodes with the loaded policy.
    options.set_controller_policy(learner.predict)

    learner.test_model(
        options, nb_episodes=nb_episodes_for_test, visualize=visualize)
Aravind Bk's avatar
Aravind Bk committed
90

Ashish Gaurav's avatar
Ashish Gaurav committed
91 92 93 94 95 96

def evaluate_high_level_policy(nb_episodes_for_test=100,
                               nb_trials=10,
                               trained_agent_file="highlevel_weights.h5f",
                               pretrained=False,
                               visualize=False):
    """Evaluate a saved high-level policy over several independent trials.

    Each trial calls agent.test_model for nb_episodes_for_test episodes;
    the mean and standard deviation of the success values and of each
    termination-reason count across the trials are printed.

    Args:
     nb_episodes_for_test: the number of episodes per trial
     nb_trials: the number of trials to run
     trained_agent_file: name of the weights file to load
     pretrained: True to look for trained_agent_file under
        backends/trained_policies/highlevel/ instead of using the name
        as given
     visualize: forwarded to agent.test_model
    """
    global options

    agent = DQNLearner(
        input_shape=(50, ),
        nb_actions=options.get_number_of_nodes(),
        low_level_policies=options.maneuvers)

    if pretrained:
        trained_agent_file = "backends/trained_policies/highlevel/" + trained_agent_file

    agent.load_model(trained_agent_file)
    options.set_controller_policy(agent.predict)

    success_list = []
    termination_reason_list = {}
    print("\nConducting {} trials of {} episodes each".format(
        nb_trials, nb_episodes_for_test))
    for trial in range(nb_trials):
        current_success, current_termination_reason = agent.test_model(
            options, nb_episodes=nb_episodes_for_test, visualize=visualize)
        print("\nTrial {}: success: {}".format(trial + 1, current_success))
        success_list.append(current_success)
        for reason, count in current_termination_reason.items():
            # Every reason gets one slot per trial, pre-filled with 0, so a
            # reason that is absent from some trial's result still counts
            # as 0 there instead of being dropped from the statistics.
            termination_reason_list.setdefault(
                reason, [0] * nb_trials)[trial] = count

    if not success_list:
        # No trial ran (nb_trials <= 0): the mean/std of an empty array
        # would only yield nan plus NumPy warnings.
        print("\nNo trials were run; nothing to summarize.")
        return

    success_list = np.array(success_list)
    print("\nSuccess: Avg: {}, Std: {}".format(
        np.mean(success_list), np.std(success_list)))
    print("Termination reason(s):")
    for reason, count_list in termination_reason_list.items():
        count_list = np.array(count_list)
        print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list),
                                            np.std(count_list)))

Ashish Gaurav's avatar
Ashish Gaurav committed
134 135 136 137 138 139 140

def find_good_high_level_policy(nb_steps=25000,
                                load_weights=False,
                                nb_episodes_for_test=100,
                                visualize=False,
                                tensorboard=False,
                                save_path="./highlevel_weights.h5f"):
    """Retrain the high-level policy until a test run is successful enough.

    Training followed by a test run is repeated until the success value
    returned by agent.test_model reaches at least 95% of
    nb_episodes_for_test. Whenever a run beats the best success value seen
    so far, the weights file at save_path is renamed to
    "highlevel_weights_<success>.h5f" in the current directory.

    Args:
     nb_steps: the number of RL steps per training run
     load_weights: forwarded to high_level_policy_training
     nb_episodes_for_test: the number of episodes per test run
     visualize: forwarded to training and to agent.test_model
     tensorboard: forwarded to high_level_policy_training
     save_path: file each training run saves its weights to
    """
    # NOTE(review): there is no upper bound on the number of retraining
    # rounds; the loop only ends once the 95% threshold is met.
    # NOTE(review): with load_weights=True, a round that follows a rename
    # will presumably fail to find save_path - confirm the intended use.
    max_num_successes = 0
    current_success = 0
    while current_success < 0.95 * nb_episodes_for_test:
        agent = high_level_policy_training(
            nb_steps=nb_steps,
            load_weights=load_weights,
            visualize=visualize,
            tensorboard=tensorboard,
            testing=False,
            # Train into the same file that is renamed below; otherwise a
            # custom save_path would never be written by the training run.
            save_path=save_path)
        options.set_controller_policy(agent.predict)

        current_success, termination_reason_counter = agent.test_model(
            options, nb_episodes=nb_episodes_for_test, visualize=visualize)

        if current_success > max_num_successes:
            os.rename(save_path,
                      "highlevel_weights_{}.h5f".format(current_success))
            max_num_successes = current_success

Ashish Gaurav's avatar
Ashish Gaurav committed
160

Aravind Bk's avatar
Aravind Bk committed
161 162 163 164
if __name__ == "__main__":
    # Command-line entry point: build the options graph once, then run any
    # combination of training, testing and evaluation that was requested.

    # Boolean switches as (flag, help text), in the order shown by --help.
    switch_specs = (
        ("--train",
         "Train a high level policy with default settings. Always saved in root folder. Always tests after training"
         ),
        ("--test",
         "Test a saved high level policy. Uses backends/trained_policies/highlevel/highlevel_weights.h5f by default"
         ),
        ("--evaluate",
         "Evaluate a saved high level policy over n trials. "
         "Uses backends/trained_policies/highlevel/highlevel_weights.h5f by default"
         ),
        ("--saved_policy_in_root",
         "Use saved policies in root of project rather than backends/trained_policies/highlevel/"
         ),
        ("--load_weights",
         "Load a saved policy from root folder first before training"),
        ("--tensorboard", "Use tensorboard while training"),
        ("--visualize",
         "Visualize the training. Testing is always visualized. Evaluation is not visualized by default"
         ),
    )

    # Integer-valued options as (flag, help text, default).
    integer_specs = (
        ("--nb_steps", "Number of steps to train for. Default is 25000",
         25000),
        ("--nb_episodes_for_test",
         "Number of episodes to test/evaluate. Default is 100", 100),
        ("--nb_trials", "Number of trials to evaluate. Default is 10", 10),
    )

    parser = argparse.ArgumentParser()
    for flag, description in switch_specs:
        parser.add_argument(flag, help=description, action="store_true")
    for flag, description, fallback in integer_specs:
        parser.add_argument(
            flag, help=description, default=fallback, type=int)
    parser.add_argument(
        "--save_file",
        help=
        "filename to save/load the trained policy. Location is as specified by --saved_policy_in_root",
        default="highlevel_weights.h5f")

    args = parser.parse_args()

    # load options graph (module-level: the functions above read it as a
    # global)
    options = OptionsGraph(
        "config.json", SimpleIntersectionEnv, randomize_special_scenarios=True)
    options.load_trained_low_level_policies()

    # Saved policies come from backends/trained_policies/highlevel/ unless
    # --saved_policy_in_root was given.
    use_pretrained_dir = not args.saved_policy_in_root

    if args.train:
        # NOTE(review): --nb_episodes_for_test is not forwarded here, so the
        # post-training test uses the training function's own default.
        high_level_policy_training(
            nb_steps=args.nb_steps,
            load_weights=args.load_weights,
            save_path=args.save_file,
            tensorboard=args.tensorboard,
            visualize=args.visualize)

    if args.test:
        high_level_policy_testing(
            visualize=True,
            nb_episodes_for_test=args.nb_episodes_for_test,
            pretrained=use_pretrained_dir,
            trained_agent_file=args.save_file)

    if args.evaluate:
        evaluate_high_level_policy(
            visualize=args.visualize,
            nb_episodes_for_test=args.nb_episodes_for_test,
            pretrained=use_pretrained_dir,
            trained_agent_file=args.save_file,
            nb_trials=args.nb_trials)