# low_level_policy_main.py
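"""Train and test low-level (maneuver) policies for the simple-intersection
environment, using the DDPG learner from backends.kerasrl_learner.
"""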
import argparse

import numpy as np

from backends.kerasrl_learner import DDPGLearner
from env.simple_intersection import SimpleIntersectionEnv
from env.simple_intersection.constants import *
from options.options_loader import OptionsGraph

# TODO: make a separate file for this function.
def low_level_policy_training(maneuver, nb_steps, RL_method='DDPG',
                              load_weights=False, training=True, testing=True,
                              visualize=False, nb_episodes_for_test=10, tensorboard=False):
    """
    Do RL of the low-level policy of the given maneuver and test it.
    Args:
     maneuver: the name of the maneuver defined in config.json (e.g., 'default').
     nb_steps: the number of steps to perform RL.
     RL_method: either DDPG or PPO2.
     load_weights: True if the pre-learned NN weights are loaded (for initializations of NNs).
     training: True to enable training.
     testing: True to enable testing.
     visualize: True to see the graphical outputs during training.
     nb_episodes_for_test: the number of episodes for testing.
    """
    # initialize the numpy random number generator
    np.random.seed()

    if not (isinstance(load_weights, bool) and isinstance(training, bool) and isinstance(testing, bool)):
        raise ValueError("load_weights, training, and testing must all be boolean.")

    if not load_weights and not training:
        raise ValueError("Both load_weights and training are False: no learning and no loading weights.")

    if not isinstance(maneuver, str):
        raise ValueError("maneuver param has to be a string.")

    if not isinstance(nb_steps, int) or nb_steps <= 0:
        raise ValueError("nb_steps has to be a positive number.")

    if RL_method not in ['DDPG', 'PPO2']:
        raise ValueError("Unsupported RL method.")

    # load options graph
    global options
    options.set_current_node(maneuver)
    options.current_node.reset()

    # TODO: add PPO2 case.
    # Use this code when you train a specific maneuver for the first time.
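    # Note on the DDPG configuration below (an explanatory sketch, not part of the
    # original code): nb_actions=2 matches the two-dimensional continuous action
    # output the learner produces for a maneuver, and the 200-step warm-ups let the
    # replay buffer fill before the critic and actor updates start.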
    agent = DDPGLearner(input_shape=(options.current_node.get_reduced_feature_length(),),
                        nb_actions=2, gamma=0.99,
                        nb_steps_warmup_critic=200,
                        nb_steps_warmup_actor=200,
                        lr=1e-3)

    if load_weights:
        agent.load_model(maneuver + "_weights.h5f")

    if training:
        agent.train(options.current_node, nb_steps=nb_steps, visualize=visualize, verbose=1,
                    log_interval=nb_steps // 4, tensorboard=tensorboard)
        agent.save_model(maneuver + "_weights.h5f")  # Save the NN weights so they can be reloaded later.

    if testing:
        # TODO: the graphical window is not closed before completing the test.
        options.current_node.learning_mode = 'testing'
        agent.test_model(options.current_node, nb_episodes=nb_episodes_for_test)

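# Example usage of low_level_policy_training (a sketch; assumes a maneuver named
# 'default' is defined in config.json and that the module-level `options` graph
# has been initialized as in __main__ below):
#
#   options = OptionsGraph("config.json", SimpleIntersectionEnv)
#   low_level_policy_training('default', nb_steps=200000)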

def low_level_policy_testing(maneuver, pretrained=False, nb_episodes_for_test=20):
    """
    Test the low-level policy of the given maneuver.

    Args:
     maneuver: the name of the maneuver defined in config.json (e.g., 'default').
     pretrained: True to load the weights from backends/trained_policies/<maneuver>/,
                 False to load them from the project root.
     nb_episodes_for_test: the number of episodes for testing.
    """
    # initialize the numpy random number generator
    np.random.seed()

    # load options graph
    global options
    options.set_current_node(maneuver)
    options.current_node.reset()

    agent = DDPGLearner(input_shape=(options.current_node.get_reduced_feature_length(),),
                        nb_actions=2, gamma=0.99,
                        nb_steps_warmup_critic=200,
                        nb_steps_warmup_actor=200,
                        lr=1e-3)

    if pretrained:
        agent.load_model("backends/trained_policies/" + maneuver + "/" + maneuver + "_weights.h5f")
    else:
        agent.load_model(maneuver + "_weights.h5f")

    options.current_node.learning_mode = 'testing'
    agent.test_model(options.current_node, nb_episodes=nb_episodes_for_test)

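# Example usage of low_level_policy_testing (a sketch; assumes pre-trained weights
# for 'default' exist under backends/trained_policies/default/):
#
#   options = OptionsGraph("config.json", SimpleIntersectionEnv)
#   low_level_policy_testing('default', pretrained=True, nb_episodes_for_test=5)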

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train",
        help="Train a low-level policy with default settings. Weights are always saved "
             "in the project root, and testing always runs after training.",
        action="store_true")
    parser.add_argument(
        "--option",
        help="The option to train, e.g. stop, keeplane, wait, changelane, follow. "
             "If not given, trains all options.")
    parser.add_argument(
        "--test",
        help="Test a saved low-level policy. Uses backends/trained_policies/<maneuver>/"
             "<maneuver>_weights.h5f by default.",
        action="store_true")
    parser.add_argument(
        "--saved_policy_in_root",
        help="Use saved policies in the root of the project rather than "
             "backends/trained_policies/<maneuver>/.",
        action="store_true")
    parser.add_argument(
        "--load_weights",
        help="Load a saved policy first before training",
        action="store_true")
    parser.add_argument(
        "--tensorboard",
        help="Use tensorboard while training",
        action="store_true")
    parser.add_argument(
        "--visualize",
        help="Visualize the training. Testing is always visualized.",
        action="store_true")
    parser.add_argument(
        "--nb_steps",
        help="Number of steps to train for. Default is 100000", default=100000, type=int)
    parser.add_argument(
        "--nb_episodes_for_test",
        help="Number of episodes to test. Default is 20", default=20, type=int)

    args = parser.parse_args()
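    # Example invocations (a sketch; the option names come from config.json):
    #
    #   python low_level_policy_main.py --train --option wait --nb_steps 200000
    #   python low_level_policy_main.py --test --option wait --saved_policy_in_root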

    options = OptionsGraph("config.json", SimpleIntersectionEnv)

    # The low-level training experiments can be reproduced roughly by running this
    # code (with nb_steps=200000 for better results).
    if args.train:
        if args.option:
            print ("Training {} maneuver...".format(args.option))
            low_level_policy_training(args.option, load_weights=args.load_weights, nb_steps=args.nb_steps,
                                      nb_episodes_for_test=args.nb_episodes_for_test, visualize=args.visualize,
                                      tensorboard=args.tensorboard)
        else:
            for option_key in options.maneuvers.keys():
                low_level_policy_training(option_key, load_weights=args.load_weights, nb_steps=args.nb_steps,
                                          nb_episodes_for_test=args.nb_episodes_for_test, visualize=args.visualize,
                                          tensorboard=args.tensorboard)

    if args.test:
        if args.option:
            low_level_policy_testing(args.option, pretrained=not args.saved_policy_in_root,
                                     nb_episodes_for_test=args.nb_episodes_for_test)
        else:
            for option_key in options.maneuvers.keys():
                low_level_policy_testing(option_key, pretrained=not args.saved_policy_in_root,
                                         nb_episodes_for_test=args.nb_episodes_for_test)