Commit 228db9b4 authored by Jae Young Lee's avatar Jae Young Lee

Merge remote-tracking branch 'origin/final_test' into final_test

parents 5c3b366b a8bc838a
......@@ -29,14 +29,24 @@ These are the minimum steps required to replicate the results for simple_interse
* Run `./scripts/install_dependencies.sh` to install python dependencies.
* Low-level policies:
* To train all low-level policies from scratch: `python3 low_level_policy_main.py --train`
* To train a single low-level policy, for example wait: `python3 low_level_policy_main.py --option=wait --train`
* To test these trained low-level policies: `python3 low_level_policy_main.py --test --saved_policy_in_root`
* To test one of these trained low-level policies, for example wait: `python3 low_level_policy_main.py --option=wait --test --saved_policy_in_root`
* You can choose to train and test all the maneuvers, but this may take some time and is not recommended.
* To train all low-level policies from scratch: `python3 low_level_policy_main.py --train`. This may take some time.
* To test all these trained low-level policies: `python3 low_level_policy_main.py --test --saved_policy_in_root`.
* Make sure the training is fully complete before running the above test.
* It is easier to verify a few of the maneuvers using the commands below:
* To train a single low-level policy, for example wait: `python3 low_level_policy_main.py --option=wait --train`.
* To test one of these trained low-level policies, for example wait: `python3 low_level_policy_main.py --option=wait --test --saved_policy_in_root`
* Available maneuvers are: wait, changelane, stop, keeplane, follow
* These results are visually evaluated.
* Note: This training has a high variance due to the continuous action space, especially for the stop and keeplane maneuvers. It may help to train for 0.2 million steps rather than the default 0.1 million by adding the argument `--nb_steps=200000` while training.
* High-level policy:
* To train high-level policy from scratch using the given low-level policies: `python3 high_level_policy_main.py --train`
* To evaluate this trained high-level policy: `python3 high_level_policy_main.py --evaluate --saved_policy_in_root`
* To run MCTS using the high-level policy: `python3 mcts.py`
* To evaluate this trained high-level policy: `python3 high_level_policy_main.py --evaluate --saved_policy_in_root`.
* The success average and standard deviation correspond to the results from the high-level policy experiments.
* To run MCTS using the high-level policy:
* To obtain a probabilities tree and save it: `python3 mcts.py --train`
* To evaluate using this saved tree: `python3 mcts.py --evaluate --saved_policy_in_root`.
* The success average and standard deviation correspond to the results from the MCTS experiments (a consolidated sketch of the full pipeline follows this list).
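For reference, a minimal end-to-end sketch of the steps above, assuming every command is run from the repository root and each stage is trained to completion before the next one starts:

```
./scripts/install_dependencies.sh
python3 low_level_policy_main.py --train
python3 low_level_policy_main.py --test --saved_policy_in_root
python3 high_level_policy_main.py --train
python3 high_level_policy_main.py --evaluate --saved_policy_in_root
python3 mcts.py --train
python3 mcts.py --evaluate --saved_policy_in_root
```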
Coding Standards
================
......
......@@ -155,7 +155,7 @@ if __name__ == "__main__":
help="Number of steps to train for. Default is 25000", default=25000, type=int)
parser.add_argument(
"--nb_episodes_for_test",
help="Number of episodes to test/evaluate. Default is 20", default=20, type=int)
help="Number of episodes to test/evaluate. Default is 100", default=100, type=int)
parser.add_argument(
"--nb_trials",
help="Number of trials to evaluate. Default is 10", default=10, type=int)
......
......@@ -102,11 +102,11 @@ if __name__ == "__main__":
help="the option to train. Eg. stop, keeplane, wait, changelane, follow. If not defined, trains all options")
parser.add_argument(
"--test",
help="Test a saved high level policy. Uses backends/trained_policies/highlevel/highlevel_weights.h5f by default",
help="Test a saved high level policy. Uses saved policy in backends/trained_policies/OPTION_NAME/ by default",
action="store_true")
parser.add_argument(
"--saved_policy_in_root",
help="Use saved policies in root of project rather than backends/trained_policies/highlevel/",
help="Use saved policies in root of project rather than backends/trained_policies/OPTION_NAME",
action="store_true")
parser.add_argument(
"--load_weights",
......@@ -141,15 +141,18 @@ if __name__ == "__main__":
tensorboard=args.tensorboard)
else:
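# No option specified: train every available maneuver in turn.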
for option_key in options.maneuvers.keys():
print("Training {} maneuver...".format(option_key))
low_level_policy_training(option_key, load_weights=args.load_weights, nb_steps=args.nb_steps,
nb_episodes_for_test=args.nb_episodes_for_test, visualize=args.visualize,
tensorboard=args.tensorboard)
if args.test:
if args.option:
print("Testing {} maneuver...".format(args.option))
low_level_policy_testing(args.option, pretrained=not args.saved_policy_in_root,
nb_episodes_for_test=args.nb_episodes_for_test)
else:
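# No option specified: test every available maneuver in turn.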
for option_key in options.maneuvers.keys():
print("Testing {} maneuver...".format(option_key))
low_level_policy_testing(option_key, pretrained=not args.saved_policy_in_root,
nb_episodes_for_test=args.nb_episodes_for_test)
\ No newline at end of file
from env.simple_intersection import SimpleIntersectionEnv
from env.simple_intersection.constants import *
from options.options_loader import OptionsGraph
from backends import DDPGLearner, DQNLearner, MCTSLearner
import pickle
import tqdm
import numpy as np
import argparse
import sys
......@@ -28,7 +26,7 @@ class Logger(object):
sys.stdout = Logger()
# TODO: make a separate file for this function.
def mcts_training(nb_traversals, save_every=20, visualize=False):
def mcts_training(nb_traversals, save_every=20, visualize=False, load_saved=False, save_file="mcts.pickle"):
"""
Train an offline MCTS tree over the high-level options, using the trained high-level policy as the predictor, and save it periodically.
Args:
......@@ -50,7 +48,9 @@ def mcts_training(nb_traversals, save_every=20, visualize=False):
agent.load_model("backends/trained_policies/highlevel/highlevel_weights.h5f")
options.set_controller_args(predictor=agent.get_softq_value_using_option_alias)
options.controller.max_depth = 20
#options.controller.load_model('backends/trained_policies/mcts/mcts.pickle')
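# Optionally resume from a previously saved tree (by default, mcts.pickle in the project root).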
if load_saved:
options.controller.load_model(save_file)
total_epochs = nb_traversals//save_every
trav_num = 1
......@@ -62,17 +62,16 @@ def mcts_training(nb_traversals, save_every=20, visualize=False):
options.controller.curr_node_num = 0
init_obs = options.reset()
v, all_ep_R = options.controller.traverse(init_obs, visualize=visualize)
# print('Traversal %d: V = %f' % (num_traversal, v))
# print('Overall Reward: %f\n' % all_ep_R)
last_rewards += [all_ep_R]
trav_num += 1
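# Checkpoint the search tree at the end of each epoch (every save_every traversals).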
options.controller.save_model('mcts_%d.pickle' % (num_epoch))
options.controller.save_model(save_file)
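# A traversal counts as a success if its overall episode reward exceeds 50.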
success = lambda x: x > 50
success_rate = np.sum(list(map(success, last_rewards)))/(len(last_rewards)*1.0)
print('success rate: %f' % success_rate)
print('Average Reward (%d-%d): %f\n' % (beg_trav_num, trav_num-1, np.mean(last_rewards)))
def mcts_evaluation(nb_traversals, num_trials=5, visualize=False):
def mcts_evaluation(nb_traversals, num_trials=5, visualize=False, save_file="mcts.pickle", pretrained=False):
"""
Evaluate a saved MCTS tree over several trials, reloading the tree at the start of each trial.
Args:
......@@ -95,11 +94,14 @@ def mcts_evaluation(nb_traversals, num_trials=5, visualize=False):
options.set_controller_args(predictor=agent.get_softq_value_using_option_alias)
options.controller.max_depth = 20
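# Pretrained trees are looked up under backends/trained_policies/mcts/; otherwise save_file is resolved relative to the project root.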
if pretrained:
save_file = "backends/trained_policies/mcts/" + save_file
success_list = []
print('Total number of trials = %d' % num_trials)
for trial in range(num_trials):
num_successes = 0
options.controller.load_model('backends/trained_policies/mcts/mcts.pickle')
options.controller.load_model(save_file)
for num_traversal in tqdm.tqdm(range(nb_traversals)):
options.controller.curr_node_num = 0
init_obs = options.reset()
......@@ -222,20 +224,51 @@ def evaluate_online_mcts(nb_episodes=20, nb_trials=5):
np.mean(count_list),
np.std(count_list)))
def mcts_visualize(file_name):
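"""Print summary information (currently the node count) from a saved MCTS tree pickle."""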
with open(file_name, 'rb') as handle:
to_restore = pickle.load(handle)
# TR = to_restore['TR']
# M = to_restore['M']
# for key, val in TR.items():
# print('%s: %f, count = %d' % (key, val/M[key], M[key]))
print(len(to_restore['nodes']))
if __name__ == "__main__":
mcts_training(nb_traversals=10000, save_every=1000, visualize=False)
# mcts_evaluation(nb_traversals=100, num_trials=10, visualize=False)
# for num in range(100): mcts_visualize('timeout_inf_save100/mcts_%d.pickle' % num)
# mcts_visualize('mcts.pickle')
#online_mcts(10)
# evaluate_online_mcts(nb_episodes=20,nb_trials=5)
parser = argparse.ArgumentParser()
parser.add_argument(
"--train",
help="Train an offline mcts with default settings. Always saved in root folder.",
action="store_true")
parser.add_argument(
"--evaluate",
help="Evaluate over n trials. "
"Uses backends/trained_policies/mcts/mcts.pickle by default",
action="store_true")
parser.add_argument(
"--saved_policy_in_root",
help="Use saved policies in root of project rather than backends/trained_policies/mcts/",
action="store_true")
parser.add_argument(
"--load_saved",
help="Load a saved policy from root folder first before training",
action="store_true")
parser.add_argument(
"--visualize",
help="Visualize the training. Testing is always visualized. Evaluation is not visualized by default",
action="store_true")
parser.add_argument(
"--nb_traversals",
help="Number of traversals to perform. Default is 1000", default=1000, type=int)
parser.add_argument(
"--save_every",
help="Saves every n traversals. Saves in root by default. Default is 500", default=500, type=int)
parser.add_argument(
"--nb_traversals_for_test",
help="Number of episodes to evaluate. Default is 100", default=100, type=int)
parser.add_argument(
"--nb_trials",
help="Number of trials to evaluate. Default is 10", default=10, type=int)
parser.add_argument(
"--save_file",
help="filename to save/load the trained policy. Location is as specified by --saved_policy_in_root. Default name is mcts.pickle",
default="mcts.pickle")
args = parser.parse_args()
if args.train:
mcts_training(nb_traversals=args.nb_traversals, save_every=args.save_every, visualize=args.visualize,
load_saved=args.load_saved, save_file=args.save_file)
if args.evaluate:
mcts_evaluation(nb_traversals=args.nb_traversals_for_test, num_trials=args.nb_trials, visualize=args.visualize,
pretrained=not args.saved_policy_in_root, save_file=args.save_file)
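# Example invocations (illustrative; values match the previously hard-coded defaults, run from the repository root):
#   python3 mcts.py --train --nb_traversals=10000 --save_every=1000
#   python3 mcts.py --evaluate --saved_policy_in_root --nb_trials=10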
\ No newline at end of file