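"""Train, test, and evaluate low-level (maneuver) policies.

Example invocations (flag names as defined by the argparse options below;
nb_steps=200000 tends to give better training results):

    python low_level_policy_main.py --train --option wait --nb_steps 200000
    python low_level_policy_main.py --test --option follow
    python low_level_policy_main.py --evaluate
"""
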
from env.simple_intersection import SimpleIntersectionEnv
from env.simple_intersection.constants import *
from options.options_loader import OptionsGraph
from backends.kerasrl_learner import DDPGLearner
from rl.callbacks import Callback

import argparse
import numpy as np


class ManeuverEvaluateCallback(Callback):
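    """Keras-RL callback that bins each test episode by its final reward.

    Episodes are counted as low (reward < -150), mid (-150 <= reward <= 150),
    or high (reward > 150). For the 'follow' maneuver, a mid-band reward is
    counted as a success; for every other maneuver, only the high band is
    (see on_train_end).
    """
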
    def __init__(self, maneuver):
        self.low_reward_count = 0
        self.mid_reward_count = 0
        self.high_reward_count = 0
        self.maneuver = maneuver
        super().__init__()

    def on_episode_end(self, episode, logs={}):
        """Called at end of each episode"""
        if logs['episode_reward'] < -150:
            self.low_reward_count += 1
        elif logs['episode_reward'] > 150:
            self.high_reward_count += 1
        else:
            self.mid_reward_count += 1

        super().on_episode_end(episode, logs)

    def on_train_end(self, logs=None):
        print("\nThe total # of episode: " + str(self.low_reward_count +
                                          self.mid_reward_count +
                                          self.high_reward_count))
        print(" # of episode with reward < -150: " + str(self.low_reward_count) +
              "\n # of episode with -150 <= reward <= 150: " + str(self.mid_reward_count) +
              "\n # of episode with reward > 150: " + str(self.high_reward_count))

        success_count = (self.mid_reward_count
                         if self.maneuver == 'follow' else
                         self.high_reward_count)

        print("\n # of success episode: " + str(success_count) + '\n')


# TODO: make a separate file for this function.
def low_level_policy_training(maneuver,
                              nb_steps,
                              RL_method='DDPG',
                              load_weights=False,
                              training=True,
                              testing=True,
                              visualize=False,
                              nb_episodes_for_test=10,
                              tensorboard=False,
                              without_ltl=False):
    """Do RL of the low-level policy of the given maneuver and test it.

    Args:
     maneuver: the name of the maneuver defined in config.json (e.g., 'default').
     nb_steps: the number of steps to perform RL.
     RL_method: either 'DDPG' or 'PPO2' (only DDPG is currently implemented).
     load_weights: True to load pre-trained NN weights (used to initialize the NNs).
     training: True to enable training.
     testing: True to enable testing.
     visualize: True to show the graphical output during training.
     nb_episodes_for_test: the number of episodes for testing.
     tensorboard: True to log the training progress to TensorBoard.
     without_ltl: True to train without the additional LTL properties.
    """
    # re-seed the numpy random number generator (no seed is given, so fresh
    # OS entropy is used and repeated runs differ)
    np.random.seed()

    if not (isinstance(load_weights, bool) and isinstance(training, bool)
            and isinstance(testing, bool)):
        raise ValueError("Type error: the variable has to be boolean.")

    if not load_weights and not training:
        raise ValueError(
            "Both load_weights and training are False: no learning and no loading weights."
        )

    if not isinstance(maneuver, str):
        raise ValueError("maneuver param has to be a string.")

    if not isinstance(nb_steps, int) or nb_steps <= 0:
        raise ValueError("nb_steps has to be a positive number.")

    if RL_method not in ['DDPG', 'PPO2']:
        raise ValueError("Unsupported RL method.")

    # load options graph
    global options
    options.set_current_node(maneuver)
    options.current_node.reset()

    # TODO: make this into a training/testing flag in optionsloader?
    options.current_node._enable_low_level_training_properties = not without_ltl

    # TODO: add PPO2 case.
    # Use this code when you train a specific maneuver for the first time.
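    # Re the TODO above, a commented-out sketch of what a PPO2 branch might
    # look like. It assumes a hypothetical PPO2Learner in backends with a
    # DDPGLearner-like constructor; this file does not confirm such a class.
    #
    # if RL_method == 'PPO2':
    #     agent = PPO2Learner(
    #         input_shape=(options.current_node.get_reduced_feature_length(), ),
    #         nb_actions=2)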
    agent = DDPGLearner(
        input_shape=(options.current_node.get_reduced_feature_length(), ),
        nb_actions=2,
        gamma=0.99,
        nb_steps_warmup_critic=200,
        nb_steps_warmup_actor=200,
        lr=1e-3)

    if load_weights:
        agent.load_model(maneuver + "_weights.h5f")

    if training:
        agent.train(
            options.current_node,
            nb_steps=nb_steps,
            visualize=visualize,
            verbose=1,
            log_interval=nb_steps // 4,
            tensorboard=tensorboard)
        agent.save_model(
            maneuver + "_weights.h5f"
        )  # Save the NN weights for reloading them in the future.

    if testing:
        # TODO: the graphical window is not closed before completing the test.
        options.current_node.learning_mode = 'testing'
        agent.test_model(
            options.current_node, nb_episodes=nb_episodes_for_test)


def low_level_policy_testing(maneuver,
                             pretrained=False,
                             visualize=True,
                             nb_episodes_for_test=20):
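    """Test the given maneuver's low-level policy and report a reward-band
    summary via ManeuverEvaluateCallback."""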

    # re-seed the numpy random number generator (no seed is given, so fresh
    # OS entropy is used and repeated runs differ)
    np.random.seed()

    # load options graph
    global options
    options.set_current_node(maneuver)
    options.current_node.reset()

    agent = DDPGLearner(
        input_shape=(options.current_node.get_reduced_feature_length(), ),
        nb_actions=2,
        gamma=0.99,
        nb_steps_warmup_critic=200,
        nb_steps_warmup_actor=200,
        lr=1e-3)

    if pretrained:
        agent.load_model("backends/trained_policies/" + maneuver + "/" +
                         maneuver + "_weights.h5f")
    else:
        agent.load_model(maneuver + "_weights.h5f")

    options.current_node.learning_mode = 'testing'
    agent.test_model(options.current_node,
                     nb_episodes=nb_episodes_for_test,
                     callbacks=[ManeuverEvaluateCallback(maneuver)],
                     visualize=visualize)


def evaluate_low_level_policy(maneuver,
                              pretrained=False,
                              nb_episodes_for_eval=100):
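    """Evaluate the given maneuver's low-level policy over
    nb_episodes_for_eval episodes, without visualization."""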

    low_level_policy_testing(maneuver, pretrained,
                             nb_episodes_for_test=nb_episodes_for_eval,
                             visualize=False)
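
# A hypothetical usage example ('wait' is one of the maneuvers named in the
# --option help below):
#
#   evaluate_low_level_policy('wait', pretrained=True)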


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train",
        help="Train a low level policy with default settings. The policy is "
        "always saved in the root folder and always tested after training.",
        action="store_true")
    parser.add_argument(
        "--option",
        help="The option to train, e.g. stop, keeplane, wait, changelane, or "
        "follow. If not given, all five options are trained.")
    parser.add_argument(
        "--test",
        help="Test a saved low level policy. By default, uses the saved policy "
        "in backends/trained_policies/OPTION_NAME/.",
        action="store_true")

    parser.add_argument(
        "--without_additional_ltl_properties",
        help=
        "Train a low level policy without additional LTL constraints.",
        action="store_true")

    parser.add_argument(
        "--evaluate",
        help="Evaluate a saved low level policy over 100 episodes. "
        "Uses backends/trained_policies/highlevel/highlevel_weights.h5f by default",
        action="store_true")

    parser.add_argument(
        "--saved_policy_in_root",
        help="Use saved policies in the project root rather than "
        "backends/trained_policies/OPTION_NAME.",
        action="store_true")
    parser.add_argument(
        "--load_weights",
        help="Load a saved policy first before training",
        action="store_true")
    parser.add_argument(
        "--tensorboard",
        help="Use tensorboard while training",
        action="store_true")
    parser.add_argument(
        "--visualize",
        help="Visualize the training. Testing is always visualized.",
        action="store_true")
    parser.add_argument(
        "--nb_steps",
        help="Number of steps to train for. Default is 100000",
        default=100000,
        type=int)
    parser.add_argument(
        "--nb_episodes_for_test",
        help="Number of episodes to test. Default is 10",
        default=10,
        type=int)

    args = parser.parse_args()

    options = OptionsGraph("config.json", SimpleIntersectionEnv)

    # The low-level training experiments can be roughly reproduced by running
    # this code (with nb_steps=200000 for better results).
    if args.train:
        if args.option:
            print("Training {} maneuver...".format(args.option))
            low_level_policy_training(
                args.option,
                load_weights=args.load_weights,
                nb_steps=args.nb_steps,
                nb_episodes_for_test=args.nb_episodes_for_test,
                visualize=args.visualize,
                tensorboard=args.tensorboard,
                without_ltl=args.without_additional_ltl_properties)
        else:
            for option_key in options.maneuvers.keys():
                print("Training {} maneuver...".format(option_key))
                low_level_policy_training(
                    option_key,
                    load_weights=args.load_weights,
                    nb_steps=args.nb_steps,
                    nb_episodes_for_test=args.nb_episodes_for_test,
                    visualize=args.visualize,
                    tensorboard=args.tensorboard,
                    without_ltl=args.without_additional_ltl_properties)

    if args.test:
        if args.option:
            print("Testing {} maneuver...".format(args.option))
            low_level_policy_testing(
                args.option,
                pretrained=not args.saved_policy_in_root,
                visualize=True,
                nb_episodes_for_test=args.nb_episodes_for_test)
        else:
            for option_key in options.maneuvers.keys():
                print("Testing {} maneuver...".format(option_key))
                low_level_policy_testing(
                    option_key,
                    pretrained=not args.saved_policy_in_root,
                    visualize=True,
                    nb_episodes_for_test=args.nb_episodes_for_test)

    if args.evaluate:
        if args.option:
            print("Evaluating {} maneuver...".format(args.option))
            evaluate_low_level_policy(
                args.option,
                pretrained=not args.saved_policy_in_root)
        else:
            for option_key in options.maneuvers.keys():
                print("Evaluating {} maneuver...".format(option_key))
                evaluate_low_level_policy(
                    option_key,
                    pretrained=not args.saved_policy_in_root)