low_level_policy_main.py
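"""Training and testing of the low-level maneuver policies.

Each maneuver defined in config.json (e.g. stop, keeplane, wait, changelane,
follow) is trained with DDPG on its sub-environment of SimpleIntersectionEnv,
and the learned weights are saved as MANEUVER_weights.h5f.

Example usage:
    python low_level_policy_main.py --train --option wait --nb_steps 200000
    python low_level_policy_main.py --test --option wait
"""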
from env.simple_intersection import SimpleIntersectionEnv
from env.simple_intersection.constants import *
from options.options_loader import OptionsGraph
from backends.kerasrl_learner import DDPGLearner

import argparse
import numpy as np


# TODO: make a separate file for this function.
def low_level_policy_training(maneuver,
                              nb_steps,
                              RL_method='DDPG',
                              load_weights=False,
                              training=True,
                              testing=True,
                              visualize=False,
                              nb_episodes_for_test=10,
                              tensorboard=False):
    """
    Do RL of the low-level policy of the given maneuver and test it.
    Args:
     maneuver: the name of the maneuver defined in config.json (e.g., 'default').
     nb_steps: the number of steps to perform RL.
     RL_method: either DDPG or PPO2.
     load_weights: True if the pre-learned NN weights are loaded (for initializations of NNs).
     training: True to enable training.
     testing: True to enable testing.
     visualize: True to see the graphical outputs during training.
     nb_episodes_for_test: the number of episodes for testing.
    """
    # initialize the numpy random number generator
    np.random.seed()

    if not (isinstance(load_weights, bool) and isinstance(training, bool)
            and isinstance(testing, bool)):
        raise ValueError("Type error: the variable has to be boolean.")

    if not load_weights and not training:
        raise ValueError(
            "Both load_weights and training are False: no learning and no loading weights."
        )

    if not isinstance(maneuver, str):
        raise ValueError("maneuver param has to be a string.")

    if not isinstance(nb_steps, int) or nb_steps <= 0:
        raise ValueError("nb_steps has to be a positive number.")

    if RL_method not in ['DDPG', 'PPO2']:
        raise ValueError("Unsupported RL method.")

    # load options graph
    global options
    options.set_current_node(maneuver)
    options.current_node.reset()

    # TODO: add PPO2 case.
    # Use this code when you train a specific maneuver for the first time.
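    # DDPG actor-critic learner over the maneuver's reduced feature vector,
    # with two continuous action outputs.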
    agent = DDPGLearner(
        input_shape=(options.current_node.get_reduced_feature_length(), ),
        nb_actions=2,
        gamma=0.99,
        nb_steps_warmup_critic=200,
        nb_steps_warmup_actor=200,
        lr=1e-3)

    if load_weights:
        agent.load_model(maneuver + "_weights.h5f")

    if training:
        agent.train(
            options.current_node,
            nb_steps=nb_steps,
            visualize=visualize,
            verbose=1,
            log_interval=nb_steps // 4,
            tensorboard=tensorboard)
        agent.save_model(
            maneuver + "_weights.h5f"
        )  # Save the NN weights for reloading them in the future.

    if testing:
        # TODO: the graphical window is not closed before completing the test.
        options.current_node.learning_mode = 'testing'
        agent.test_model(
            options.current_node, nb_episodes=nb_episodes_for_test)


def low_level_policy_testing(maneuver,
                             pretrained=False,
                             nb_episodes_for_test=20):
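    """
    Test a trained low-level policy of the given maneuver.

    Args:
     maneuver: the name of the maneuver defined in config.json (e.g., 'default').
     pretrained: True to load the weights shipped in backends/trained_policies/,
        False to load MANEUVER_weights.h5f from the project root.
     nb_episodes_for_test: the number of episodes for testing.
    """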

    # initialize the numpy random number generator
    np.random.seed()

    # load options graph
    global options
    options.set_current_node(maneuver)
    options.current_node.reset()

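    # Rebuild the same DDPG architecture used in low_level_policy_training so
    # the saved weights can be loaded into an identically shaped network.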
    agent = DDPGLearner(
        input_shape=(options.current_node.get_reduced_feature_length(), ),
        nb_actions=2,
        gamma=0.99,
        nb_steps_warmup_critic=200,
        nb_steps_warmup_actor=200,
        lr=1e-3)

    if pretrained:
        agent.load_model("backends/trained_policies/" + maneuver + "/" +
                         maneuver + "_weights.h5f")
    else:
        agent.load_model(maneuver + "_weights.h5f")

    options.current_node.learning_mode = 'testing'
    agent.test_model(options.current_node, nb_episodes=nb_episodes_for_test)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train",
        help=
        "Train a low-level policy with default settings. Weights are always saved in the project root. Testing always runs after training.",
        action="store_true")
    parser.add_argument(
        "--option",
        help=
        "The option to train, e.g. stop, keeplane, wait, changelane, follow. If not given, all options are trained."
    )
    parser.add_argument(
        "--test",
        help=
        "Test a saved low-level policy. Uses the saved policy in backends/trained_policies/OPTION_NAME/ by default.",
        action="store_true")
    parser.add_argument(
        "--saved_policy_in_root",
        help=
        "Use saved policies in the project root rather than backends/trained_policies/OPTION_NAME.",
        action="store_true")
    parser.add_argument(
        "--load_weights",
        help="Load a saved policy first before training",
        action="store_true")
    parser.add_argument(
        "--tensorboard",
        help="Use tensorboard while training",
        action="store_true")
    parser.add_argument(
        "--visualize",
        help="Visualize the training. Testing is always visualized.",
        action="store_true")
    parser.add_argument(
        "--nb_steps",
        help="Number of steps to train for. Default is 100000",
        default=100000,
        type=int)
    parser.add_argument(
        "--nb_episodes_for_test",
        help="Number of episodes to test. Default is 20",
        default=20,
        type=int)

    args = parser.parse_args()

    options = OptionsGraph("config.json", SimpleIntersectionEnv)

    # The low-level training experiments can be roughly reproduced by running
    # this code (with nb_steps=200000 for better results).
    if args.train:
        if args.option:
            print("Training {} maneuver...".format(args.option))
            low_level_policy_training(
                args.option,
                load_weights=args.load_weights,
                nb_steps=args.nb_steps,
                nb_episodes_for_test=args.nb_episodes_for_test,
                visualize=args.visualize,
                tensorboard=args.tensorboard)
        else:
            for option_key in options.maneuvers.keys():
                print("Training {} maneuver...".format(option_key))
                low_level_policy_training(
                    option_key,
                    load_weights=args.load_weights,
                    nb_steps=args.nb_steps,
                    nb_episodes_for_test=args.nb_episodes_for_test,
                    visualize=args.visualize,
                    tensorboard=args.tensorboard)

    if args.test:
        if args.option:
            print("Testing {} maneuver...".format(args.option))
            low_level_policy_testing(
                args.option,
                pretrained=not args.saved_policy_in_root,
                nb_episodes_for_test=args.nb_episodes_for_test)
        else:
            for option_key in options.maneuvers.keys():
                print("Testing {} maneuver...".format(option_key))
                low_level_policy_testing(
                    option_key,
                    pretrained=not args.saved_policy_in_root,
                    nb_episodes_for_test=args.nb_episodes_for_test)