# low_level_policy_main.py
import argparse

import numpy as np

from env.simple_intersection import SimpleIntersectionEnv
from env.simple_intersection.constants import *
from options.options_loader import OptionsGraph
from backends.kerasrl_learner import DDPGLearner

# TODO: make a separate file for this function.
def low_level_policy_training(maneuver, nb_steps, RL_method='DDPG',
                              load_weights=False, training=True, testing=True,
                              visualize=False, nb_episodes_for_test=10, tensorboard=False):
    """
    Do RL of the low-level policy of the given maneuver and test it.
    Args:
     maneuver: the name of the maneuver defined in config.json (e.g., 'default').
     nb_steps: the number of steps to perform RL.
     RL_method: either DDPG or PPO2.
     load_weights: True if the pre-learned NN weights are loaded (for initializations of NNs).
     training: True to enable training.
     testing: True to enable testing.
     visualize: True to see the graphical outputs during training.
     nb_episodes_for_test: the number of episodes for testing.
    """
    # initialize the numpy random number generator
    np.random.seed()

    if not (isinstance(load_weights, bool) and isinstance(training, bool) and isinstance(testing, bool)):
        raise ValueError("Type error: the variable has to be boolean.")

    if not load_weights and not training:
        raise ValueError("Both load_weights and training are False: no learning and no loading weights.")

    if not isinstance(maneuver, str):
        raise ValueError("maneuver param has to be a string.")

    if not isinstance(nb_steps, int) or nb_steps <= 0:
        raise ValueError("nb_steps has to be a positive number.")

    if RL_method not in ['DDPG', 'PPO2']:
        raise ValueError("Unsupported RL method.")

    # load options graph
    global options
    options.set_current_node(maneuver)
    options.current_node.reset()

    # TODO: add PPO2 case.
    # Use this code when you train a specific maneuver for the first time.
    agent = DDPGLearner(input_shape=(options.current_node.get_reduced_feature_length(),),
                        nb_actions=2, gamma=0.99,
                        nb_steps_warmup_critic=200,
                        nb_steps_warmup_actor=200,
                        lr=1e-3)
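
    # A possible shape for the PPO2 case flagged in the TODO above (a sketch
    # only: 'PPO2Learner' and its constructor arguments are assumptions, not
    # an existing API in backends.kerasrl_learner):
    #
    #   if RL_method == 'PPO2':
    #       agent = PPO2Learner(input_shape=(options.current_node.get_reduced_feature_length(),),
    #                           nb_actions=2, lr=1e-3)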

    if load_weights:
        agent.load_model(maneuver + "_weights.h5f")

    if training:
        agent.train(options.current_node, nb_steps=nb_steps, visualize=visualize, verbose=1,
                    log_interval=nb_steps // 4, tensorboard=tensorboard)
        agent.save_model(maneuver + "_weights.h5f")  # Save the NN weights for reloading them in the future.

    if testing:
        # TODO: the graphical window is not closed before completing the test.
        options.current_node.learning_mode = 'testing'
        agent.test_model(options.current_node, nb_episodes=nb_episodes_for_test)
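
# Example usage (a minimal sketch; assumes 'wait' is a maneuver defined in
# config.json and that the module-level `options` graph has been created as
# in the __main__ block below):
#
#   options = OptionsGraph("config.json", SimpleIntersectionEnv)
#   low_level_policy_training('wait', nb_steps=100000)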


def low_level_policy_testing(maneuver, pretrained=False, nb_episodes_for_test=20):
    """
    Test the low-level policy of the given maneuver.

    Args:
     maneuver: the name of the maneuver defined in config.json (e.g., 'default').
     pretrained: True to load the weights saved under backends/trained_policies/,
         False to load MANEUVER_weights.h5f from the root of the project.
     nb_episodes_for_test: the number of episodes for testing.
    """

    # initialize the numpy random number generator
    np.random.seed()

    # load options graph
    global options
    options.set_current_node(maneuver)
    options.current_node.reset()

    agent = DDPGLearner(input_shape=(options.current_node.get_reduced_feature_length(),),
                        nb_actions=2, gamma=0.99,
                        nb_steps_warmup_critic=200,
                        nb_steps_warmup_actor=200,
                        lr=1e-3)

    if pretrained:
        agent.load_model("backends/trained_policies/" + maneuver + "/" + maneuver + "_weights.h5f")
    else:
        agent.load_model(maneuver + "_weights.h5f")

    options.current_node.learning_mode = 'testing'
    agent.test_model(options.current_node, nb_episodes=nb_episodes_for_test)
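
# Example usage (sketch; assumes a pretrained 'wait' policy exists under
# backends/trained_policies/wait/):
#
#   options = OptionsGraph("config.json", SimpleIntersectionEnv)
#   low_level_policy_testing('wait', pretrained=True, nb_episodes_for_test=20)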


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train",
        help="Train a high level policy with default settings. Always saved in root folder. Always tests after training",
        action="store_true")
    parser.add_argument(
        "--option",
        help="the option to train. Eg. stop, keeplane, wait, changelane, follow. If not defined, trains all options")
    parser.add_argument(
        "--test",
        help="Test a saved high level policy. Uses saved policy in backends/trained_policies/OPTION_NAME/ by default",
        action="store_true")
    parser.add_argument(
        "--saved_policy_in_root",
        help="Use saved policies in root of project rather than backends/trained_policies/OPTION_NAME",
        action="store_true")
    parser.add_argument(
        "--load_weights",
        help="Load a saved policy first before training",
        action="store_true")
    parser.add_argument(
        "--tensorboard",
        help="Use tensorboard while training",
        action="store_true")
    parser.add_argument(
        "--visualize",
        help="Visualize the training. Testing is always visualized.",
        action="store_true")
    parser.add_argument(
        "--nb_steps",
        help="Number of steps to train for. Default is 100000", default=100000, type=int)
    parser.add_argument(
        "--nb_episodes_for_test",
        help="Number of episodes to test. Default is 20", default=20, type=int)

    args = parser.parse_args()

    options = OptionsGraph("config.json", SimpleIntersectionEnv)

    # The low-level training experiments can be roughly reproduced by running
    # this code (with nb_steps=200000 for better results).
    if args.train:
        if args.option:
            print ("Training {} maneuver...".format(args.option))
            low_level_policy_training(args.option, load_weights=args.load_weights, nb_steps=args.nb_steps,
                                      nb_episodes_for_test=args.nb_episodes_for_test, visualize=args.visualize,
                                      tensorboard=args.tensorboard)
        else:
            for option_key in options.maneuvers.keys():
                print("Training {} maneuver...".format(option_key))
                low_level_policy_training(option_key, load_weights=args.load_weights, nb_steps=args.nb_steps,
                                          nb_episodes_for_test=args.nb_episodes_for_test, visualize=args.visualize,
                                          tensorboard=args.tensorboard)

    if args.test:
        if args.option:
            print("Testing {} maneuver...".format(args.option))
            low_level_policy_testing(args.option, pretrained=not args.saved_policy_in_root,
                                     nb_episodes_for_test=args.nb_episodes_for_test)
        else:
            for option_key in options.maneuvers.keys():
                print("Testing {} maneuver...".format(option_key))
                low_level_policy_testing(option_key, pretrained=not args.saved_policy_in_root,
                                         nb_episodes_for_test=args.nb_episodes_for_test)
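
# Example command lines (assuming execution from the project root, where
# config.json and any saved *_weights.h5f files live):
#
#   python low_level_policy_main.py --train --option wait --nb_steps 200000
#   python low_level_policy_main.py --test --option wait --nb_episodes_for_test 20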