From 382065ac131b5ed636a3e4c16eaa044e8b502670 Mon Sep 17 00:00:00 2001
From: Unknown <aravindbk92@gmail.com>
Date: Sun, 18 Nov 2018 02:28:42 -0500
Subject: [PATCH] Updated readme and running

---
 README.txt                | 21 +++++++---
 high_level_policy_main.py |  2 +-
 low_level_policy_main.py  |  7 +++-
 mcts.py                   | 85 +++++++++++++++++++++++++++------------
 4 files changed, 80 insertions(+), 35 deletions(-)

diff --git a/README.txt b/README.txt
index a82c360..9382cb1 100644
--- a/README.txt
+++ b/README.txt
@@ -29,14 +29,23 @@ These are the minimum steps required to replicate the results for simple_interse

 * Run `./scripts/install_dependencies.sh` to install python dependencies.
 * Low-level policies:
-    * To train all low-level policies from scratch: `python3 low_level_policy_main.py --train`
-    * To train a single low-level, for example wait: `python3 low_level_policy_main.py --option=wait --train`
-    * To test these trained low-level policies: `python3 low_level_policy_main.py --test --saved_policy_in_root`
-    * To test one of these trained low-level policies, for example wait: `python3 low_level_policy_main.py --option=wait --test --saved_policy_in_root`
+    * You can choose to train and test all the maneuvers, but this may take some time and is not recommended.
+        * To train all low-level policies from scratch: `python3 low_level_policy_main.py --train`. This may take some time.
+        * To test all these trained low-level policies: `python3 low_level_policy_main.py --test --saved_policy_in_root`.
+        * Make sure the training is fully complete before running the above test.
+    * It is easier to verify a few of the maneuvers using the commands below:
+        * To train a single low-level policy, for example wait: `python3 low_level_policy_main.py --option=wait --train`.
+        * To test one of these trained low-level policies, for example wait: `python3 low_level_policy_main.py --option=wait --test --saved_policy_in_root`
+    * Available maneuvers are: wait, changelane, stop, keeplane, follow
+    * These results are evaluated visually.
 * High-level policy:
     * To train high-level policy from scratch using the given low-level policies: `python3 high_level_policy_main.py --train`
-    * To evaluate this trained high-level policy: `python3 high_level_policy_main.py --evaluate --saved_policy_in_root`
-* To run MCTS using the high-level policy: `python3 mcts.py`
+    * To evaluate this trained high-level policy: `python3 high_level_policy_main.py --evaluate --saved_policy_in_root`.
+        * The success average and standard deviation correspond to the results from the high-level policy experiments.
+* To run MCTS using the high-level policy:
+    * To obtain a probabilities tree and save it: `python3 mcts.py --train`
+    * To evaluate using this saved tree: `python3 mcts.py --evaluate --saved_policy_in_root`.
+        * The success average and standard deviation correspond to the results from the MCTS experiments.

 Coding Standards
 ================
diff --git a/high_level_policy_main.py b/high_level_policy_main.py
index a9b4108..e22ad07 100644
--- a/high_level_policy_main.py
+++ b/high_level_policy_main.py
@@ -155,7 +155,7 @@ if __name__ == "__main__":
         help="Number of steps to train for. Default is 25000", default=25000, type=int)
     parser.add_argument(
         "--nb_episodes_for_test",
-        help="Number of episodes to test/evaluate. Default is 20", default=20, type=int)
+        help="Number of episodes to test/evaluate. Default is 100", default=100, type=int)
     parser.add_argument(
         "--nb_trials",
         help="Number of trials to evaluate. Default is 10", default=10, type=int)
diff --git a/low_level_policy_main.py b/low_level_policy_main.py
index 291ea82..d4a129a 100644
--- a/low_level_policy_main.py
+++ b/low_level_policy_main.py
@@ -102,11 +102,11 @@ if __name__ == "__main__":
         help="the option to train. Eg. stop, keeplane, wait, changelane, follow. If not defined, trains all options")
     parser.add_argument(
         "--test",
-        help="Test a saved high level policy. Uses backends/trained_policies/highlevel/highlevel_weights.h5f by default",
+        help="Test a saved low-level policy. Uses the saved policy in backends/trained_policies/OPTION_NAME/ by default",
         action="store_true")
     parser.add_argument(
         "--saved_policy_in_root",
-        help="Use saved policies in root of project rather than backends/trained_policies/highlevel/",
+        help="Use saved policies in root of project rather than backends/trained_policies/OPTION_NAME/",
         action="store_true")
     parser.add_argument(
         "--load_weights",
@@ -141,15 +141,18 @@ if __name__ == "__main__":
                                       tensorboard=args.tensorboard)
         else:
             for option_key in options.maneuvers.keys():
+                print("Training {} maneuver...".format(option_key))
                 low_level_policy_training(option_key, load_weights=args.load_weights, nb_steps=args.nb_steps,
                                           nb_episodes_for_test=args.nb_episodes_for_test, visualize=args.visualize,
                                           tensorboard=args.tensorboard)

     if args.test:
         if args.option:
+            print("Testing {} maneuver...".format(args.option))
             low_level_policy_testing(args.option, pretrained=not args.saved_policy_in_root,
                                      nb_episodes_for_test=args.nb_episodes_for_test)
         else:
             for option_key in options.maneuvers.keys():
+                print("Testing {} maneuver...".format(option_key))
                 low_level_policy_testing(args.option, pretrained=not args.saved_policy_in_root,
                                          nb_episodes_for_test=args.nb_episodes_for_test)
\ No newline at end of file
diff --git a/mcts.py b/mcts.py
index 95ef93f..9bd2abb 100644
--- a/mcts.py
+++ b/mcts.py
@@ -1,11 +1,9 @@
 from env.simple_intersection import SimpleIntersectionEnv
-from env.simple_intersection.constants import *
 from options.options_loader import OptionsGraph
 from backends import DDPGLearner, DQNLearner, MCTSLearner
-import pickle
-import tqdm
 import numpy as np
 import tqdm
+import argparse


 import sys
@@ -28,7 +26,7 @@ class Logger(object):
 sys.stdout = Logger()

 # TODO: make a separate file for this function.
-def mcts_training(nb_traversals, save_every=20, visualize=False):
+def mcts_training(nb_traversals, save_every=20, visualize=False, load_saved=False, save_file="mcts.pickle"):
     """ Do RL of the low-level policy of the given maneuver and test it.

     Args:
@@ -50,7 +48,9 @@
     agent.load_model("backends/trained_policies/highlevel/highlevel_weights.h5f")
     options.set_controller_args(predictor = agent.get_softq_value_using_option_alias)
     options.controller.max_depth = 20
-    #options.controller.load_model('backends/trained_policies/mcts/mcts.pickle')
+
+    if load_saved:
+        options.controller.load_model(save_file)

     total_epochs = nb_traversals//save_every
     trav_num = 1
@@ -62,17 +62,16 @@
             options.controller.curr_node_num = 0
             init_obs = options.reset()
             v, all_ep_R = options.controller.traverse(init_obs, visualize=visualize)
-            # print('Traversal %d: V = %f' % (num_traversal, v))
-            # print('Overall Reward: %f\n' % all_ep_R)
+            last_rewards += [all_ep_R]
             trav_num += 1
-        options.controller.save_model('mcts_%d.pickle' % (num_epoch))
+        options.controller.save_model(save_file)
         success = lambda x: x > 50
         success_rate = np.sum(list(map(success, last_rewards)))/(len(last_rewards)*1.0)
         print('success rate: %f' % success_rate)
         print('Average Reward (%d-%d): %f\n' % (beg_trav_num, trav_num-1, np.mean(last_rewards)))


-def mcts_evaluation(nb_traversals, num_trials=5, visualize=False):
+def mcts_evaluation(nb_traversals, num_trials=5, visualize=False, save_file="mcts.pickle", pretrained=False):
     """ Do RL of the low-level policy of the given maneuver and test it.

     Args:
@@ -95,11 +94,14 @@
     options.set_controller_args(predictor=agent.get_softq_value_using_option_alias)
     options.controller.max_depth = 20

+    if pretrained:
+        save_file = "backends/trained_policies/mcts/" + save_file
+
     success_list = []
     print('Total number of trials = %d' % num_trials)
     for trial in range(num_trials):
         num_successes = 0
-        options.controller.load_model('backends/trained_policies/mcts/mcts.pickle')
+        options.controller.load_model(save_file)
         for num_traversal in tqdm.tqdm(range(nb_traversals)):
             options.controller.curr_node_num = 0
             init_obs = options.reset()
@@ -222,20 +224,51 @@
                                       np.mean(count_list), np.std(count_list)))


-def mcts_visualize(file_name):
-    with open(file_name, 'rb') as handle:
-        to_restore = pickle.load(handle)
-        # TR = to_restore['TR']
-        # M = to_restore['M']
-        # for key, val in TR.items():
-        #     print('%s: %f, count = %d' % (key, val/M[key], M[key]))
-        print(len(to_restore['nodes']))
-
 if __name__ == "__main__":
-
-    mcts_training(nb_traversals=10000, save_every=1000, visualize=False)
-    # mcts_evaluation(nb_traversals=100, num_trials=10, visualize=False)
-    # for num in range(100): mcts_visualize('timeout_inf_save100/mcts_%d.pickle' % num)
-    # mcts_visualize('mcts.pickle')
-    #online_mcts(10)
-    # evaluate_online_mcts(nb_episodes=20,nb_trials=5)
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--train",
+        help="Train an offline MCTS with default settings. The tree is always saved in the root folder.",
+        action="store_true")
+    parser.add_argument(
+        "--evaluate",
+        help="Evaluate over n trials. "
" + "Uses backends/trained_policies/mcts/mcts.pickle by default", + action="store_true") + parser.add_argument( + "--saved_policy_in_root", + help="Use saved policies in root of project rather than backends/trained_policies/mcts/", + action="store_true") + parser.add_argument( + "--load_saved", + help="Load a saved policy from root folder first before training", + action="store_true") + parser.add_argument( + "--visualize", + help="Visualize the training. Testing is always visualized. Evaluation is not visualized by default", + action="store_true") + parser.add_argument( + "--nb_traversals", + help="Number of traversals to perform. Default is 1000", default=1000, type=int) + parser.add_argument( + "--save_every", + help="Saves every n traversals. Saves in root by default. Default is 500", default=500, type=int) + parser.add_argument( + "--nb_traversals_for_test", + help="Number of episodes to evaluate. Default is 100", default=100, type=int) + parser.add_argument( + "--nb_trials", + help="Number of trials to evaluate. Default is 10", default=10, type=int) + parser.add_argument( + "--save_file", + help="filename to save/load the trained policy. Location is as specified by --saved_policy_in_root. Default name is mcts.pickle", + default="mcts.pickle") + + args = parser.parse_args() + + if args.train: + mcts_training(nb_traversals=args.nb_traversals, save_every=args.save_every, visualize=args.visualize, + load_saved=args.load_saved, save_file=args.save_file) + if args.evaluate: + mcts_evaluation(nb_traversals=args.nb_traversals_for_test, num_trials=args.nb_trials, visualize=args.visualize, + pretrained=not args.saved_policy_in_root, save_file=args.save_file) \ No newline at end of file -- GitLab