Commit 382065ac authored by Unknown

Updated readme and running

parent d1dbbd44
......@@ -29,14 +29,23 @@ These are the minimum steps required to replicate the results for simple_interse
* Run `./scripts/install_dependencies.sh` to install python dependencies.
* Low-level policies:
* To train all low-level policies from scratch: `python3 low_level_policy_main.py --train`
* To train a single low-level policy, for example wait: `python3 low_level_policy_main.py --option=wait --train`
* To test these trained low-level policies: `python3 low_level_policy_main.py --test --saved_policy_in_root`
* To test one of these trained low-level policies, for example wait: `python3 low_level_policy_main.py --option=wait --test --saved_policy_in_root`
* You can choose to train and test all the maneuvers, but this may take some time and is not recommended.
* To train all low-level policies from scratch: `python3 low_level_policy_main.py --train`. This may take some time.
* To test all these trained low-level policies: `python3 low_level_policy_main.py --test --saved_policy_in_root`.
* Make sure the training is fully complete before running the above test.
* It is easier to verify a few of the maneuvers using the commands below:
* To train a single low-level policy, for example wait: `python3 low_level_policy_main.py --option=wait --train`.
* To test one of these trained low-level policies, for example wait: `python3 low_level_policy_main.py --option=wait --test --saved_policy_in_root`
* Available maneuvers are: wait, changelane, stop, keeplane, follow
* These results are visually evaluated.
* High-level policy:
* To train high-level policy from scratch using the given low-level policies: `python3 high_level_policy_main.py --train`
* To evaluate this trained high-level policy: `python3 high_level_policy_main.py --evaluate --saved_policy_in_root`
* To run MCTS using the high-level policy: `python3 mcts.py`
* To evaluate this trained high-level policy: `python3 high_level_policy_main.py --evaluate --saved_policy_in_root`.
* The success average and standard deviation correspond to the results from the high-level policy experiments.
* To run MCTS using the high-level policy:
* To obtain a probabilities tree and save it: `python3 mcts.py --train` (the saved tree can be inspected as sketched below this list)
* To evaluate using this saved tree: `python3 mcts.py --evaluate --saved_policy_in_root`.
* The success average and standard deviation correspond to the results from the MCTS experiments.
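The tree saved by `python3 mcts.py --train` is a pickle file written to the project root (named `mcts.pickle` by default). As a minimal sketch, assuming the dictionary layout used by `mcts_visualize` in `mcts.py`, it can be inspected like this:

```python
import pickle

# Load the tree saved by `python3 mcts.py --train`
# (default filename and root location assumed, as noted above).
with open("mcts.pickle", "rb") as handle:
    tree = pickle.load(handle)

# 'nodes' holds the expanded search-tree nodes.
print("Number of nodes in the saved tree:", len(tree["nodes"]))
```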
Coding Standards
================
......
......@@ -155,7 +155,7 @@ if __name__ == "__main__":
help="Number of steps to train for. Default is 25000", default=25000, type=int)
parser.add_argument(
"--nb_episodes_for_test",
help="Number of episodes to test/evaluate. Default is 20", default=20, type=int)
help="Number of episodes to test/evaluate. Default is 100", default=100, type=int)
parser.add_argument(
"--nb_trials",
help="Number of trials to evaluate. Default is 10", default=10, type=int)
......
......@@ -102,11 +102,11 @@ if __name__ == "__main__":
help="the option to train. Eg. stop, keeplane, wait, changelane, follow. If not defined, trains all options")
parser.add_argument(
"--test",
help="Test a saved high level policy. Uses backends/trained_policies/highlevel/highlevel_weights.h5f by default",
help="Test a saved high level policy. Uses saved policy in backends/trained_policies/OPTION_NAME/ by default",
action="store_true")
parser.add_argument(
"--saved_policy_in_root",
help="Use saved policies in root of project rather than backends/trained_policies/highlevel/",
help="Use saved policies in root of project rather than backends/trained_policies/OPTION_NAME",
action="store_true")
parser.add_argument(
"--load_weights",
......@@ -141,15 +141,18 @@ if __name__ == "__main__":
tensorboard=args.tensorboard)
else:
for option_key in options.maneuvers.keys():
print("Training {} maneuver...".format(option_key))
low_level_policy_training(option_key, load_weights=args.load_weights, nb_steps=args.nb_steps,
nb_episodes_for_test=args.nb_episodes_for_test, visualize=args.visualize,
tensorboard=args.tensorboard)
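# Testing: evaluate either a single saved maneuver policy (--option) or every maneuver in the graph.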
if args.test:
if args.option:
print("Testing {} maneuver...".format(args.option))
low_level_policy_testing(args.option, pretrained=not args.saved_policy_in_root,
nb_episodes_for_test=args.nb_episodes_for_test)
else:
for option_key in options.maneuvers.keys():
print("Testing {} maneuver...".format(option_key))
low_level_policy_testing(option_key, pretrained=not args.saved_policy_in_root,
nb_episodes_for_test=args.nb_episodes_for_test)
\ No newline at end of file
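A minimal sketch of driving the same train/test path programmatically rather than through the CLI. It assumes `low_level_policy_training` and `low_level_policy_testing` can be imported from `low_level_policy_main` without side effects; the keyword arguments simply mirror the calls shown above and the argparse defaults.

```python
# Hypothetical programmatic use; argument names follow the calls in low_level_policy_main.py.
from low_level_policy_main import low_level_policy_training, low_level_policy_testing

# Train the 'wait' maneuver from scratch (defaults: 25000 steps, 100 test episodes).
low_level_policy_training("wait", load_weights=False, nb_steps=25000,
                          nb_episodes_for_test=100, visualize=False, tensorboard=False)

# Test the freshly trained policy saved in the project root
# (pretrained=False corresponds to passing --saved_policy_in_root on the CLI).
low_level_policy_testing("wait", pretrained=False, nb_episodes_for_test=100)
```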
from env.simple_intersection import SimpleIntersectionEnv
from env.simple_intersection.constants import *
from options.options_loader import OptionsGraph
from backends import DDPGLearner, DQNLearner, MCTSLearner
import pickle
import tqdm
import numpy as np
import argparse
import sys
......@@ -28,7 +26,7 @@ class Logger(object):
sys.stdout = Logger()
# TODO: make a separate file for this function.
def mcts_training(nb_traversals, save_every=20, visualize=False):
def mcts_training(nb_traversals, save_every=20, visualize=False, load_saved=False, save_file="mcts.pickle"):
"""
Run offline MCTS training traversals using the trained high-level policy as the predictor, periodically saving the tree.
Args:
......@@ -50,7 +48,9 @@ def mcts_training(nb_traversals, save_every=20, visualize=False):
agent.load_model("backends/trained_policies/highlevel/highlevel_weights.h5f")
options.set_controller_args(predictor = agent.get_softq_value_using_option_alias)
options.controller.max_depth = 20
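# Use the trained high-level policy's soft Q-values as the MCTS predictor; search depth is capped at 20.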
#options.controller.load_model('backends/trained_policies/mcts/mcts.pickle')
if load_saved:
options.controller.load_model(save_file)
total_epochs = nb_traversals//save_every
trav_num = 1
......@@ -62,17 +62,16 @@ def mcts_training(nb_traversals, save_every=20, visualize=False):
options.controller.curr_node_num = 0
init_obs = options.reset()
v, all_ep_R = options.controller.traverse(init_obs, visualize=visualize)
# print('Traversal %d: V = %f' % (num_traversal, v))
# print('Overall Reward: %f\n' % all_ep_R)
last_rewards += [all_ep_R]
trav_num += 1
options.controller.save_model('mcts_%d.pickle' % (num_epoch))
options.controller.save_model(save_file)
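# A traversal counts as a success when its overall episode reward exceeds 50;
# the success rate below is the fraction of such traversals in last_rewards.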
success = lambda x: x > 50
success_rate = np.sum(list(map(success, last_rewards)))/(len(last_rewards)*1.0)
print('success rate: %f' % success_rate)
print('Average Reward (%d-%d): %f\n' % (beg_trav_num, trav_num-1, np.mean(last_rewards)))
def mcts_evaluation(nb_traversals, num_trials=5, visualize=False):
def mcts_evaluation(nb_traversals, num_trials=5, visualize=False, save_file="mcts.pickle", pretrained=False):
"""
Evaluate a saved MCTS tree: run nb_traversals traversals per trial, over num_trials trials.
Args:
......@@ -95,11 +94,14 @@ def mcts_evaluation(nb_traversals, num_trials=5, visualize=False):
options.set_controller_args(predictor=agent.get_softq_value_using_option_alias)
options.controller.max_depth = 20
if pretrained:
save_file = "backends/trained_policies/mcts/" + save_file
success_list = []
print('Total number of trials = %d' % num_trials)
for trial in range(num_trials):
num_successes = 0
options.controller.load_model('backends/trained_policies/mcts/mcts.pickle')
options.controller.load_model(save_file)
for num_traversal in tqdm.tqdm(range(nb_traversals)):
options.controller.curr_node_num = 0
init_obs = options.reset()
......@@ -222,20 +224,51 @@ def evaluate_online_mcts(nb_episodes=20, nb_trials=5):
np.mean(count_list),
np.std(count_list)))
def mcts_visualize(file_name):
with open(file_name, 'rb') as handle:
to_restore = pickle.load(handle)
# TR = to_restore['TR']
# M = to_restore['M']
# for key, val in TR.items():
# print('%s: %f, count = %d' % (key, val/M[key], M[key]))
print(len(to_restore['nodes']))
if __name__ == "__main__":
mcts_training(nb_traversals=10000, save_every=1000, visualize=False)
# mcts_evaluation(nb_traversals=100, num_trials=10, visualize=False)
# for num in range(100): mcts_visualize('timeout_inf_save100/mcts_%d.pickle' % num)
# mcts_visualize('mcts.pickle')
#online_mcts(10)
# evaluate_online_mcts(nb_episodes=20,nb_trials=5)
parser = argparse.ArgumentParser()
parser.add_argument(
"--train",
help="Train an offline mcts with default settings. Always saved in root folder.",
action="store_true")
parser.add_argument(
"--evaluate",
help="Evaluate over n trials. "
"Uses backends/trained_policies/mcts/mcts.pickle by default",
action="store_true")
parser.add_argument(
"--saved_policy_in_root",
help="Use saved policies in root of project rather than backends/trained_policies/mcts/",
action="store_true")
parser.add_argument(
"--load_saved",
help="Load a saved policy from root folder first before training",
action="store_true")
parser.add_argument(
"--visualize",
help="Visualize the training. Testing is always visualized. Evaluation is not visualized by default",
action="store_true")
parser.add_argument(
"--nb_traversals",
help="Number of traversals to perform. Default is 1000", default=1000, type=int)
parser.add_argument(
"--save_every",
help="Saves every n traversals. Saves in root by default. Default is 500", default=500, type=int)
parser.add_argument(
"--nb_traversals_for_test",
help="Number of episodes to evaluate. Default is 100", default=100, type=int)
parser.add_argument(
"--nb_trials",
help="Number of trials to evaluate. Default is 10", default=10, type=int)
parser.add_argument(
"--save_file",
help="filename to save/load the trained policy. Location is as specified by --saved_policy_in_root. Default name is mcts.pickle",
default="mcts.pickle")
args = parser.parse_args()
if args.train:
mcts_training(nb_traversals=args.nb_traversals, save_every=args.save_every, visualize=args.visualize,
load_saved=args.load_saved, save_file=args.save_file)
if args.evaluate:
mcts_evaluation(nb_traversals=args.nb_traversals_for_test, num_trials=args.nb_trials, visualize=args.visualize,
pretrained=not args.saved_policy_in_root, save_file=args.save_file)
\ No newline at end of file