From 382065ac131b5ed636a3e4c16eaa044e8b502670 Mon Sep 17 00:00:00 2001
From: Unknown <aravindbk92@gmail.com>
Date: Sun, 18 Nov 2018 02:28:42 -0500
Subject: [PATCH] Updated readme and running

---
 README.txt                | 21 +++++++---
 high_level_policy_main.py |  2 +-
 low_level_policy_main.py  |  7 +++-
 mcts.py                   | 85 +++++++++++++++++++++++++++------------
 4 files changed, 80 insertions(+), 35 deletions(-)

diff --git a/README.txt b/README.txt
index a82c360..9382cb1 100644
--- a/README.txt
+++ b/README.txt
@@ -29,14 +29,23 @@ These are the minimum steps required to replicate the results for simple_interse
 
 * Run `./scripts/install_dependencies.sh` to install python dependencies.
 * Low-level policies:
-    * To train all low-level policies from scratch: `python3 low_level_policy_main.py --train`
-    * To train a single low-level, for example wait: `python3 low_level_policy_main.py --option=wait --train`
-    * To test these trained low-level policies: `python3 low_level_policy_main.py --test --saved_policy_in_root`
-    * To test one of these trained low-level policies, for example wait: `python3 low_level_policy_main.py --option=wait --test --saved_policy_in_root`
+    * You can train and test all the maneuvers at once, but this takes a long time and is not recommended.
+        * To train all low-level policies from scratch: `python3 low_level_policy_main.py --train`.
+        * To test all these trained low-level policies: `python3 low_level_policy_main.py --test --saved_policy_in_root`.
+        * Make sure training has fully completed before running the above test.
+    * It is easier to verify a few individual maneuvers using the commands below:
+        * To train a single low-level policy, for example wait: `python3 low_level_policy_main.py --option=wait --train`.
+        * To test one of these trained low-level policies, for example wait: `python3 low_level_policy_main.py --option=wait --test --saved_policy_in_root`
+        * Available maneuvers are: wait, changelane, stop, keeplane, follow
+    * These results are evaluated visually.
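+    * Optional flags: `--load_weights` is defined in `low_level_policy_main.py`; the script also reads `nb_steps`, `nb_episodes_for_test` and `visualize` arguments, so flags with those names can likely be passed to the commands above.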
 * High-level policy:
     * To train high-level policy from scratch using the given low-level policies: `python3 high_level_policy_main.py --train`
-    * To evaluate this trained high-level policy: `python3 high_level_policy_main.py --evaluate --saved_policy_in_root`
-* To run MCTS using the high-level policy: `python3 mcts.py`
+    * To evaluate this trained high-level policy: `python3 high_level_policy_main.py --evaluate --saved_policy_in_root`.
+    * The reported success average and standard deviation correspond to the results of the high-level policy experiments.
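+    * Optional flags defined in `high_level_policy_main.py`: `--nb_trials` (default 10) and `--nb_episodes_for_test` (default 100) control how many trials, and episodes per trial, are used for evaluation.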
+* To run MCTS using the high-level policy:
+    * To build a probabilities tree and save it: `python3 mcts.py --train`
+    * To evaluate using this saved tree: `python3 mcts.py --evaluate --saved_policy_in_root`.
+    * The success average and standard deviation correspond to the results from the MCTS experiments.
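+    * Optional flags defined in `mcts.py`: `--nb_traversals` (default 1000), `--save_every` (default 500), `--nb_trials` (default 10), `--nb_traversals_for_test` (default 100) and `--save_file` (default `mcts.pickle`). For example (illustrative values): `python3 mcts.py --train --nb_traversals=1000 --save_every=500`.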
 
 Coding Standards
 ================
diff --git a/high_level_policy_main.py b/high_level_policy_main.py
index a9b4108..e22ad07 100644
--- a/high_level_policy_main.py
+++ b/high_level_policy_main.py
@@ -155,7 +155,7 @@ if __name__ == "__main__":
         help="Number of steps to train for. Default is 25000", default=25000, type=int)
     parser.add_argument(
         "--nb_episodes_for_test",
-        help="Number of episodes to test/evaluate. Default is 20", default=20, type=int)
+        help="Number of episodes to test/evaluate. Default is 100", default=100, type=int)
     parser.add_argument(
         "--nb_trials",
         help="Number of trials to evaluate. Default is 10", default=10, type=int)
diff --git a/low_level_policy_main.py b/low_level_policy_main.py
index 291ea82..d4a129a 100644
--- a/low_level_policy_main.py
+++ b/low_level_policy_main.py
@@ -102,11 +102,11 @@ if __name__ == "__main__":
         help="the option to train. Eg. stop, keeplane, wait, changelane, follow. If not defined, trains all options")
     parser.add_argument(
         "--test",
-        help="Test a saved high level policy. Uses backends/trained_policies/highlevel/highlevel_weights.h5f by default",
+        help="Test a saved high level policy. Uses saved policy in backends/trained_policies/OPTION_NAME/ by default",
         action="store_true")
     parser.add_argument(
         "--saved_policy_in_root",
-        help="Use saved policies in root of project rather than backends/trained_policies/highlevel/",
+        help="Use saved policies in root of project rather than backends/trained_policies/OPTION_NAME",
         action="store_true")
     parser.add_argument(
         "--load_weights",
@@ -141,15 +141,18 @@ if __name__ == "__main__":
                                       tensorboard=args.tensorboard)
         else:
             for option_key in options.maneuvers.keys():
+                print("Training {} maneuver...".format(option_key))
                 low_level_policy_training(option_key, load_weights=args.load_weights, nb_steps=args.nb_steps,
                                           nb_episodes_for_test=args.nb_episodes_for_test, visualize=args.visualize,
                                           tensorboard=args.tensorboard)
 
     if args.test:
         if args.option:
+            print("Testing {} maneuver...".format(args.option))
             low_level_policy_testing(args.option, pretrained=not args.saved_policy_in_root,
                                      nb_episodes_for_test=args.nb_episodes_for_test)
         else:
             for option_key in options.maneuvers.keys():
+                print("Testing {} maneuver...".format(option_key))
                 low_level_policy_testing(option_key, pretrained=not args.saved_policy_in_root,
                                          nb_episodes_for_test=args.nb_episodes_for_test)
\ No newline at end of file
diff --git a/mcts.py b/mcts.py
index 95ef93f..9bd2abb 100644
--- a/mcts.py
+++ b/mcts.py
@@ -1,11 +1,9 @@
 from env.simple_intersection import SimpleIntersectionEnv
-from env.simple_intersection.constants import *
 from options.options_loader import OptionsGraph
 from backends import DDPGLearner, DQNLearner, MCTSLearner
-import pickle
-import tqdm
 import numpy as np
 import tqdm
+import argparse
 
 import sys
 
@@ -28,7 +26,7 @@ class Logger(object):
 sys.stdout = Logger()
 
 # TODO: make a separate file for this function.
-def mcts_training(nb_traversals, save_every=20, visualize=False):
+def mcts_training(nb_traversals, save_every=20, visualize=False, load_saved=False, save_file="mcts.pickle"):
     """
     Train the MCTS tree by repeated traversals, using the trained high-level policy as predictor, and periodically save it.
     Args:
@@ -50,7 +48,9 @@ def mcts_training(nb_traversals, save_every=20, visualize=False):
     agent.load_model("backends/trained_policies/highlevel/highlevel_weights.h5f")
     options.set_controller_args(predictor = agent.get_softq_value_using_option_alias)
     options.controller.max_depth = 20
-    #options.controller.load_model('backends/trained_policies/mcts/mcts.pickle')
+
+    if load_saved:
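+        # Resume from an existing tree pickle (per --load_saved, expected in the project root).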
+        options.controller.load_model(save_file)
 
     total_epochs = nb_traversals//save_every
     trav_num = 1
@@ -62,17 +62,16 @@ def mcts_training(nb_traversals, save_every=20, visualize=False):
             options.controller.curr_node_num = 0
             init_obs = options.reset()
             v, all_ep_R = options.controller.traverse(init_obs, visualize=visualize)
-            # print('Traversal %d: V = %f' % (num_traversal, v))
-            # print('Overall Reward: %f\n' % all_ep_R)
+
             last_rewards += [all_ep_R]
             trav_num += 1
-        options.controller.save_model('mcts_%d.pickle' % (num_epoch))
+        options.controller.save_model(save_file)
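+        # A traversal counts as a success when its overall episode reward exceeds 50.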
         success = lambda x: x > 50
         success_rate = np.sum(list(map(success, last_rewards)))/(len(last_rewards)*1.0)
         print('success rate: %f' % success_rate)
         print('Average Reward (%d-%d): %f\n' % (beg_trav_num, trav_num-1, np.mean(last_rewards)))
 
-def mcts_evaluation(nb_traversals, num_trials=5, visualize=False):
+def mcts_evaluation(nb_traversals, num_trials=5, visualize=False, save_file="mcts.pickle", pretrained=False):
     """
     Evaluate a saved MCTS tree over a number of trials.
     Args:
@@ -95,11 +94,14 @@ def mcts_evaluation(nb_traversals, num_trials=5, visualize=False):
     options.set_controller_args(predictor=agent.get_softq_value_using_option_alias)
     options.controller.max_depth = 20
 
+    if pretrained:
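+        # Use the pre-trained tree shipped under backends/trained_policies/mcts/ unless
+        # --saved_policy_in_root was given.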
+        save_file = "backends/trained_policies/mcts/" + save_file
+
     success_list = []
     print('Total number of trials = %d' % num_trials)
     for trial in range(num_trials):
         num_successes = 0
-        options.controller.load_model('backends/trained_policies/mcts/mcts.pickle')
+        options.controller.load_model(save_file)
         for num_traversal in tqdm.tqdm(range(nb_traversals)):
             options.controller.curr_node_num = 0
             init_obs = options.reset()
@@ -222,20 +224,51 @@ def evaluate_online_mcts(nb_episodes=20, nb_trials=5):
                                             np.mean(count_list),
                                             np.std(count_list)))
 
-def mcts_visualize(file_name):
-    with open(file_name, 'rb') as handle:
-        to_restore = pickle.load(handle)
-    # TR = to_restore['TR']
-    # M = to_restore['M']
-    # for key, val in TR.items():
-    #    print('%s: %f, count = %d' % (key, val/M[key], M[key]))
-    print(len(to_restore['nodes']))
-
 if __name__ == "__main__":
-
-    mcts_training(nb_traversals=10000, save_every=1000, visualize=False)
-    # mcts_evaluation(nb_traversals=100, num_trials=10, visualize=False)
-    # for num in range(100): mcts_visualize('timeout_inf_save100/mcts_%d.pickle' % num)
-    # mcts_visualize('mcts.pickle')
-    #online_mcts(10)
-    # evaluate_online_mcts(nb_episodes=20,nb_trials=5)
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--train",
+        help="Train an offline mcts with default settings. Always saved in root folder.",
+        action="store_true")
+    parser.add_argument(
+        "--evaluate",
+        help="Evaluate over n trials. "
+             "Uses backends/trained_policies/mcts/mcts.pickle by default",
+        action="store_true")
+    parser.add_argument(
+        "--saved_policy_in_root",
+        help="Use saved policies in root of project rather than backends/trained_policies/mcts/",
+        action="store_true")
+    parser.add_argument(
+        "--load_saved",
+        help="Load a saved policy from root folder first before training",
+        action="store_true")
+    parser.add_argument(
+        "--visualize",
+        help="Visualize the training. Testing is always visualized. Evaluation is not visualized by default",
+        action="store_true")
+    parser.add_argument(
+        "--nb_traversals",
+        help="Number of traversals to perform. Default is 1000", default=1000, type=int)
+    parser.add_argument(
+        "--save_every",
+        help="Saves every n traversals. Saves in root by default. Default is 500", default=500, type=int)
+    parser.add_argument(
+        "--nb_traversals_for_test",
+        help="Number of episodes to evaluate. Default is 100", default=100, type=int)
+    parser.add_argument(
+        "--nb_trials",
+        help="Number of trials to evaluate. Default is 10", default=10, type=int)
+    parser.add_argument(
+        "--save_file",
+        help="filename to save/load the trained policy. Location is as specified by --saved_policy_in_root. Default name is mcts.pickle",
+        default="mcts.pickle")
+
+    args = parser.parse_args()
+
+    if args.train:
+        mcts_training(nb_traversals=args.nb_traversals, save_every=args.save_every, visualize=args.visualize,
+                      load_saved=args.load_saved, save_file=args.save_file)
+    if args.evaluate:
+        mcts_evaluation(nb_traversals=args.nb_traversals_for_test, num_trials=args.nb_trials, visualize=args.visualize,
+                        pretrained=not args.saved_policy_in_root, save_file=args.save_file)
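+
+    # Example invocations (illustrative values matching the defaults above):
+    #   python3 mcts.py --train --nb_traversals=1000 --save_every=500
+    #   python3 mcts.py --evaluate --nb_trials=10 --nb_traversals_for_test=100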
\ No newline at end of file
-- 
GitLab