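"""Train, test, and evaluate the high-level (maneuver-selection) policy that
chooses among the trained low-level maneuver policies.

Run this file directly; the argument parser in the __main__ block lists the
available command-line options.
"""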
import argparse
import os

import numpy as np

from env.simple_intersection import SimpleIntersectionEnv
from env.simple_intersection.constants import *
from options.options_loader import OptionsGraph
from backends.kerasrl_learner import DQNLearner


# TODO: make a separate file for this function.
def high_level_policy_training(nb_steps=25000, load_weights=False, training=True,
                               testing=True, nb_episodes_for_test=10,
                               max_nb_steps=100, visualize=False, tensorboard=False,
                               save_path="highlevel_weights.h5f"):
    """
    Do RL of the high-level policy and test it.
    Args:
     nb_steps: the number of steps to perform RL
     load_weights: True if the pre-learned NN weights are loaded (for initializations of NNs)
     training: True to enable training
     testing: True to enable testing
     nb_episodes_for_test: the number of episodes for testing
    """
    # re-seed the numpy random number generator (fresh, nondeterministic seed)
    np.random.seed()

    if not (isinstance(load_weights, bool) and isinstance(training, bool) and isinstance(testing, bool)):
        raise ValueError("load_weights, training, and testing must all be boolean.")

    if not load_weights and not training:
        raise ValueError("Both load_weights and training are False: "
                         "no weights to load and no training to perform.")

    if not isinstance(nb_steps, int) or nb_steps <= 0:
        raise ValueError("nb_steps must be a positive integer.")

    global options

    agent = DQNLearner(input_shape=(50,), nb_actions=options.get_number_of_nodes(),
                       target_model_update=1e-3,
                       delta_clip=100,
                       low_level_policies=options.maneuvers)

    if load_weights:
        agent.load_model(save_path)

    if training:
        if visualize:
            options.visualize_low_level_steps = True
        agent.train(options, nb_steps=nb_steps, nb_max_episode_steps=max_nb_steps, tensorboard=tensorboard)
        agent.save_model(save_path)

    if testing:
        options.set_controller_policy(agent.predict)
        agent.test_model(options, nb_episodes=nb_episodes_for_test)

    return agent
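
# A minimal programmatic sketch (hypothetical values; assumes the `options`
# graph has been initialized as in the __main__ block below):
#
#   agent = high_level_policy_training(nb_steps=10000, testing=False)
#   options.set_controller_policy(agent.predict)
#   agent.test_model(options, nb_episodes=10)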


def high_level_policy_testing(nb_episodes_for_test=100, trained_agent_file="highlevel_weights.h5f",
                              pretrained=False, visualize=True):
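    """
    Test a trained high-level policy.

    Args:
     nb_episodes_for_test: the number of episodes used for testing
     trained_agent_file: filename of the saved weights
     pretrained: True to load the weights from backends/trained_policies/highlevel/
     visualize: True to visualize the test episodes
    """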
    global options

    agent = DQNLearner(input_shape=(50,), nb_actions=options.get_number_of_nodes(),
                       low_level_policies=options.maneuvers)

    if pretrained:
        trained_agent_file = "backends/trained_policies/highlevel/" + trained_agent_file

    agent.load_model(trained_agent_file)
    options.set_controller_policy(agent.predict)

    agent.test_model(options, nb_episodes=nb_episodes_for_test, visualize=visualize)

def evaluate_high_level_policy(nb_episodes_for_test=100, nb_trials=10,
                               trained_agent_file="highlevel_weights.h5f", pretrained=False, visualize=False):
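    """
    Evaluate a trained high-level policy over several trials and report the
    mean and standard deviation of the success counts and termination reasons.

    Args:
     nb_episodes_for_test: the number of episodes per trial
     nb_trials: the number of trials
     trained_agent_file: filename of the saved weights
     pretrained: True to load the weights from backends/trained_policies/highlevel/
     visualize: True to visualize the test episodes
    """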
    global options

    agent = DQNLearner(input_shape=(50,), nb_actions=options.get_number_of_nodes(),
                       low_level_policies=options.maneuvers)

    if pretrained:
        trained_agent_file = "backends/trained_policies/highlevel/" + trained_agent_file

    agent.load_model(trained_agent_file)
    options.set_controller_policy(agent.predict)

    success_list = []
    termination_reason_list = {}
    print ("\nConducting {} trials of {} episodes each".format(nb_trials,nb_episodes_for_test))
    for trial in range(nb_trials):
        current_success, current_termination_reason = agent.test_model(options, nb_episodes=nb_episodes_for_test, visualize=visualize)
        print("\nTrial {}: success: {}".format(trial+1, current_success))
        success_list.append(current_success)
        for reason, count in current_termination_reason.items():
            if reason in termination_reason_list:
                termination_reason_list[reason].append(count)
            else:
                termination_reason_list[reason] = [count]

    success_list = np.array(success_list)
    print ("\nSuccess: Avg: {}, Std: {}".format(np.mean(success_list), np.std(success_list)))
    print ("Termination reason(s):")
    for reason, count_list in termination_reason_list.items():
        count_list = np.array(count_list)
        print("{}: Avg: {}, Std: {}".format(reason,
                                            np.mean(count_list),
                                            np.std(count_list)))

def find_good_high_level_policy(nb_steps=25000, load_weights=False, nb_episodes_for_test=100,
                                visualize=False, tensorboard=False, save_path="./highlevel_weights.h5f"):
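    """
    Repeatedly train high-level policies from scratch until one succeeds in at
    least 95% of the test episodes; each new best set of weights is renamed to
    highlevel_weights_<success_count>.h5f.

    Args:
     nb_steps: the number of RL training steps per attempt
     load_weights: True to initialize each attempt from saved weights
     nb_episodes_for_test: the number of episodes used for testing each attempt
     visualize: True to visualize training/testing
     tensorboard: True to log training with TensorBoard
     save_path: file path the trained weights are saved to
    """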
    max_num_successes = 0
    current_success = 0
    while current_success < 0.95 * nb_episodes_for_test:
        agent = high_level_policy_training(nb_steps=nb_steps, load_weights=load_weights, visualize=visualize,
                                           tensorboard=tensorboard, testing=False)
        options.set_controller_policy(agent.predict)

        current_success, _ = agent.test_model(options, nb_episodes=nb_episodes_for_test, visualize=visualize)

        if current_success > max_num_successes:
            os.rename(save_path, "highlevel_weights_{}.h5f".format(current_success))
            max_num_successes = current_success

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train",
        help="Train a high level policy with default settings. Always saved in root folder. Always tests after training",
        action="store_true")
    parser.add_argument(
        "--test",
        help="Test a saved high level policy. Uses backends/trained_policies/highlevel/highlevel_weights.h5f by default",
        action="store_true")
    parser.add_argument(
        "--evaluate",
        help="Evaluate a saved high level policy over n trials. "
             "Uses backends/trained_policies/highlevel/highlevel_weights.h5f by default",
        action="store_true")
    parser.add_argument(
        "--saved_policy_in_root",
        help="Use saved policies in root of project rather than backends/trained_policies/highlevel/",
        action="store_true")
    parser.add_argument(
        "--load_weights",
        help="Load a saved policy from root folder first before training",
        action="store_true")
    parser.add_argument(
        "--tensorboard",
        help="Use tensorboard while training",
        action="store_true")
    parser.add_argument(
        "--visualize",
        help="Visualize the training. Testing is always visualized. Evaluation is not visualized by default",
        action="store_true")
    parser.add_argument(
        "--nb_steps",
        help="Number of steps to train for. Default is 25000", default=25000, type=int)
    parser.add_argument(
        "--nb_episodes_for_test",
        help="Number of episodes to test/evaluate. Default is 20", default=20, type=int)
    parser.add_argument(
        "--nb_trials",
        help="Number of trials to evaluate. Default is 10", default=10, type=int)
    parser.add_argument(
        "--save_file",
        help="filename to save/load the trained policy. Location is as specified by --saved_policy_in_root",
        default="highlevel_weights.h5f")

    args = parser.parse_args()

    # load options graph
    options = OptionsGraph("config.json", SimpleIntersectionEnv, randomize_special_scenarios=True)
    options.load_trained_low_level_policies()

    if args.train:
        high_level_policy_training(nb_steps=args.nb_steps, load_weights=args.load_weights, save_path=args.save_file,
                                   tensorboard=args.tensorboard, nb_episodes_for_test=args.nb_episodes_for_test,
                                   visualize=args.visualize)

    if args.test:
        high_level_policy_testing(visualize=True, nb_episodes_for_test=args.nb_episodes_for_test,
                                  pretrained=not args.saved_policy_in_root, trained_agent_file=args.save_file)

    if args.evaluate:
        evaluate_high_level_policy(visualize=args.visualize, nb_episodes_for_test=args.nb_episodes_for_test,
                                   pretrained=not args.saved_policy_in_root, trained_agent_file=args.save_file,
                                   nb_trials=args.nb_trials)
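
# Example invocations (assuming config.json and the trained low-level policies
# are available, as loaded above):
#
#   python high_level_policy_main.py --train --nb_steps 25000 --tensorboard
#   python high_level_policy_main.py --test --nb_episodes_for_test 20
#   python high_level_policy_main.py --evaluate --nb_trials 10 --nb_episodes_for_test 100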