from worlds.simple_intersection import SimpleIntersectionEnv
from worlds.simple_intersection.constants import *
from options.options_loader import OptionsGraph
from backends.kerasrl_learner import DQNLearner
import os
import argparse


# TODO: make a separate file for this function.
def high_level_policy_training(nb_steps=25000,
                               load_weights=False,
                               training=True,
                               testing=True,
                               nb_episodes_for_test=20,
                               max_nb_steps=100,
                               visualize=False,
                               tensorboard=False,
                               save_path="highlevel_weights.h5f"):
    """Do RL of the high-level policy and test it.

    Args:
     nb_steps: the number of steps to perform RL
     load_weights: True if the pre-learned NN weights are loaded from
        save_path (for initializations of NNs)
     training: True to enable training
     testing: True to enable testing
     nb_episodes_for_test: the number of episodes for testing
     max_nb_steps: forwarded to agent.train as nb_max_episode_steps
     visualize: True to set options.visualize_low_level_steps before training
     tensorboard: forwarded to agent.train
     save_path: file the weights are loaded from (load_weights) and saved to
        (training); the test run after training loads this same file

    Returns:
     the DQNLearner agent that was created (and loaded and/or trained)

    Raises:
     ValueError: if load_weights, training or testing is not a bool, if both
        load_weights and training are False, or if nb_steps is not a
        positive int
    """
    # initialize the numpy random number generator
    np.random.seed()

    if not (isinstance(load_weights, bool) and isinstance(training, bool)
            and isinstance(testing, bool)):
        raise ValueError("Type error: the variable has to be boolean.")

    if not load_weights and not training:
        raise ValueError(
            "Both load_weights and training are False: no learning and no loading weights."
        )

    if not isinstance(nb_steps, int) or nb_steps <= 0:
        raise ValueError("nb_steps has to be a positive number.")

    global options

    agent = DQNLearner(
        input_shape=(50, ),
        nb_actions=options.get_number_of_nodes(),
        target_model_update=1e-3,
        delta_clip=100,
        low_level_policies=options.maneuvers)

    if load_weights:
        agent.load_model(save_path)

    if training:
        if visualize:
            options.visualize_low_level_steps = True
        agent.train(
            options,
            nb_steps=nb_steps,
            nb_max_episode_steps=max_nb_steps,
            tensorboard=tensorboard)
        agent.save_model(save_path)

    if testing:
        # Test the weights this call just loaded/saved. Without forwarding
        # save_path the test would always read the default weights file,
        # even when a different save_path was requested.
        high_level_policy_testing(
            nb_episodes_for_test=nb_episodes_for_test,
            trained_agent_file=save_path)

    return agent


def high_level_policy_testing(nb_episodes_for_test=100,
                              trained_agent_file="highlevel_weights.h5f",
                              pretrained=False,
                              visualize=True):
    """Load a saved high-level policy and run it for some test episodes.

    Args:
     nb_episodes_for_test: the number of episodes passed to agent.test_model
     trained_agent_file: the weights file to load
     pretrained: True to look for trained_agent_file under
        backends/trained_policies/highlevel/ instead of using it as given
     visualize: forwarded to agent.test_model

    The value returned by agent.test_model is discarded; this function
    returns None.
    """
    global options

    nb_actions = options.get_number_of_nodes()
    learner = DQNLearner(
        input_shape=(50, ),
        nb_actions=nb_actions,
        low_level_policies=options.maneuvers)

    # Pretrained policies live in a fixed folder; otherwise use the path as is.
    weights_file = (
        "backends/trained_policies/highlevel/" + trained_agent_file
        if pretrained else trained_agent_file)
    learner.load_model(weights_file)

    # Let the options graph pick its next node with the loaded policy.
    options.set_controller_policy(learner.predict)

    learner.test_model(
        options, nb_episodes=nb_episodes_for_test, visualize=visualize)


def evaluate_high_level_policy(nb_episodes_for_test=100,
                               nb_trials=10,
                               trained_agent_file="highlevel_weights.h5f",
                               pretrained=False,
                               visualize=False):
    """Run several test trials of a saved high-level policy and print statistics.

    Each trial calls agent.test_model, which is unpacked here as a success
    value plus a mapping from termination reason to count. The mean and
    standard deviation over all trials are printed for the success value and
    for every termination reason that occurred at least once.

    Args:
     nb_episodes_for_test: the number of episodes in each trial
     nb_trials: the number of trials
     trained_agent_file: the weights file to load
     pretrained: True to look for trained_agent_file under
        backends/trained_policies/highlevel/ instead of using it as given
     visualize: forwarded to agent.test_model
    """
    global options

    nb_actions = options.get_number_of_nodes()
    learner = DQNLearner(
        input_shape=(50, ),
        nb_actions=nb_actions,
        low_level_policies=options.maneuvers)

    weights_file = (
        "backends/trained_policies/highlevel/" + trained_agent_file
        if pretrained else trained_agent_file)
    learner.load_model(weights_file)
    options.set_controller_policy(learner.predict)

    successes = []
    counts_by_reason = {}
    print("\nConducting {} trials of {} episodes each".format(
        nb_trials, nb_episodes_for_test))
    for trial_idx in range(nb_trials):
        trial_success, trial_reasons = learner.test_model(
            options, nb_episodes=nb_episodes_for_test, visualize=visualize)
        print("\nTrial {}: success: {}".format(trial_idx + 1, trial_success))
        successes.append(trial_success)
        # Collect one count per trial for every reason seen in that trial.
        for reason, count in trial_reasons.items():
            counts_by_reason.setdefault(reason, []).append(count)

    successes = np.array(successes)
    success_avg = np.mean(successes)
    success_std = np.std(successes)
    print("\nSuccess: Avg: {}, Std: {}".format(success_avg, success_std))

    print("Termination reason(s):")
    for reason, counts in counts_by_reason.items():
        padded = np.array(counts)
        # A reason missing from a trial counts as zero for that trial, so pad
        # up to one entry per trial before averaging.
        for _ in range(nb_trials - padded.size):
            padded = np.append(padded, 0)

        print("{}: Avg: {}, Std: {}".format(reason, np.mean(padded),
                                            np.std(padded)))


def find_good_high_level_policy(nb_steps=25000,
                                load_weights=False,
                                nb_episodes_for_test=100,
                                visualize=False,
                                tensorboard=False,
                                save_path="./highlevel_weights.h5f"):
    """Train repeatedly until a policy succeeds in at least 95% of test episodes.

    Each round trains via high_level_policy_training (without its built-in
    test), evaluates the returned agent with agent.test_model, and, whenever
    the success value beats the best one so far, renames the saved weights
    to "highlevel_weights_<success>.h5f" in the current working directory.
    The loop ends once the success value reaches 0.95 * nb_episodes_for_test;
    there is no upper bound on the number of rounds.

    Args:
     nb_steps: the number of steps to perform RL in each round
     load_weights: forwarded to high_level_policy_training
     nb_episodes_for_test: the number of test episodes per round
     visualize: forwarded to training and to agent.test_model
     tensorboard: forwarded to high_level_policy_training
     save_path: file each round's weights are saved to before being renamed
    """
    max_num_successes = 0
    current_success = 0
    required_successes = 0.95 * nb_episodes_for_test
    while current_success < required_successes:
        # save_path must be forwarded: the rename below expects the weights
        # of this round to be at save_path, not at the training default.
        agent = high_level_policy_training(
            nb_steps=nb_steps,
            load_weights=load_weights,
            visualize=visualize,
            tensorboard=tensorboard,
            testing=False,
            save_path=save_path)
        options.set_controller_policy(agent.predict)

        current_success, _ = agent.test_model(
            options, nb_episodes=nb_episodes_for_test, visualize=visualize)

        if current_success > max_num_successes:
            # NOTE(review): the rename removes save_path, so a later round
            # with load_weights=True would try to load a file that is gone;
            # confirm whether a copy is intended here instead.
            os.rename(save_path,
                      "highlevel_weights_{}.h5f".format(current_success))
            max_num_successes = current_success


if __name__ == "__main__":
    # Command-line entry point: parse the flags, build the options graph,
    # then run whichever of --train / --test / --evaluate were requested.
    parser = argparse.ArgumentParser()

    # On/off switches, registered in the order they should appear in --help.
    for _flag, _help in (
        ("--train",
         "Train a high level policy with default settings. Always saved in root folder. Always tests after training"),
        ("--test",
         "Test a saved high level policy. Uses backends/trained_policies/highlevel/highlevel_weights.h5f by default"),
        ("--evaluate",
         "Evaluate a saved high level policy over n trials. "
         "Uses backends/trained_policies/highlevel/highlevel_weights.h5f by default"),
        ("--saved_policy_in_root",
         "Use saved policies in root of project rather than backends/trained_policies/highlevel/"),
        ("--load_weights",
         "Load a saved policy from root folder first before training"),
        ("--tensorboard",
         "Use tensorboard while training"),
        ("--visualize",
         "Visualize the training. Testing is always visualized. Evaluation is not visualized by default"),
    ):
        parser.add_argument(_flag, help=_help, action="store_true")

    # Integer-valued settings with their defaults.
    for _flag, _default, _help in (
        ("--nb_steps", 200000,
         "Number of steps to train for. Default is 200000"),
        ("--nb_episodes_for_test", 100,
         "Number of episodes to test/evaluate. Default is 100"),
        ("--nb_trials", 10,
         "Number of trials to evaluate. Default is 10"),
    ):
        parser.add_argument(_flag, help=_help, default=_default, type=int)

    parser.add_argument(
        "--save_file",
        help=
        "filename to save/load the trained policy. Location is as specified by --saved_policy_in_root",
        default="highlevel_weights.h5f")

    args = parser.parse_args()

    # load options graph (module-level name read by the functions above)
    options = OptionsGraph(
        "config.json", SimpleIntersectionEnv, randomize_special_scenarios=True)
    options.load_trained_low_level_policies()

    if args.train:
        high_level_policy_training(
            nb_steps=args.nb_steps,
            load_weights=args.load_weights,
            save_path=args.save_file,
            tensorboard=args.tensorboard,
            visualize=args.visualize)

    if args.test:
        # Testing is always visualized, regardless of --visualize.
        high_level_policy_testing(
            visualize=True,
            nb_episodes_for_test=args.nb_episodes_for_test,
            pretrained=not args.saved_policy_in_root,
            trained_agent_file=args.save_file)

    if args.evaluate:
        evaluate_high_level_policy(
            visualize=args.visualize,
            nb_episodes_for_test=args.nb_episodes_for_test,
            pretrained=not args.saved_policy_in_root,
            trained_agent_file=args.save_file,
            nb_trials=args.nb_trials)