# options_loader.py
import json
from .simple_intersection.maneuvers import *
from .simple_intersection.mcts_maneuvers import *
from backends import RLController, DDPGLearner, MCTSLearner, OnlineMCTSController, ManualPolicy


class OptionsGraph:
    """
    Represent the options graph as a graph like structure. The configuration
    is specified in a json file and consists of the following specific values:

    * nodes: dictionary of policy node aliases -> maneuver classes
    * edges: dictionary such that it has key value pairs which represent
        edges between policy node aliases
    * starting_node: ego's starting node
    """
    #: visualization flag passed through to the controller on each step
    visualize_low_level_steps = False

    def __init__(self, json_path, env_class, *env_args, **env_kwargs):
        """
        Constructor for the options graph.

        Args:
            json_path: path to the json file which stores the configuration
                for this options graph
            env_class: the gym compliant environment class
            env_args: arguments to pass when initializing the env
            env_kwargs: named arguments to pass when initializing the env

        Raises:
            Exception: if config["method"] is not one of
                "rl", "manual", "mcts" or "online_mcts"
        """

        # Use a context manager so the config file handle is closed promptly
        # (json.load(open(...)) would leak the file descriptor).
        with open(json_path) as config_file:
            self.config = json.load(config_file)

        #: Gym compliant environment instance. This instance also has a
        #  back-reference to this options instance.
        self.env = env_class(*env_args, **env_kwargs)
        #: nodes of the graph (alias -> maneuver class name)
        self.nodes = self.config["nodes"]
        #: edges of the graph (alias -> successor value from the config)
        self.edges = self.config["edges"]
        #: adjacency list from the graph (only the keys are needed here)
        self.adj = {key: [] for key in self.nodes}
        for source, target in self.edges.items():
            self.adj[source].append(target)

        #: Hold the different maneuver objects, instantiated by looking up the
        #  class name from the config in this module's globals (populated by
        #  the wildcard imports at the top of the file).
        #  Each maneuver object also has a back-reference to this options
        #  instance. TODO: allow for passing arguments to maneuver
        #  objects at init, for example, noise
        # TODO: Change name maneuver as it is SimpleIntersectionEnv specific.
        self.maneuvers = {key: globals()[value](self.env) for key, value in self.nodes.items()}
        #: ordered aliases so a maneuver can be referenced using an int index
        self.maneuvers_alias = list(self.maneuvers.keys())

        #: starting node
        #  TODO: reimplement if needed for multi-agent training?
        self.start_node_alias = self.config["start_node"]

        #: high level policy over low level policies, chosen by config["method"]
        method = self.config["method"]
        if method == "rl":
            self.controller = RLController(self.env, self.maneuvers, self.start_node_alias)
        elif method == "manual":
            self.controller = ManualPolicy(self.env, self.maneuvers, self.adj, self.start_node_alias)
        elif method == "mcts":
            self.controller = MCTSLearner(self.env, self.maneuvers, self.start_node_alias)
        elif method == "online_mcts":
            self.controller = OnlineMCTSController(self.env, self.maneuvers, self.start_node_alias)
        else:
            # Original message lacked a separator and read
            # "OptionsGraphController to be used not specified".
            raise Exception(self.__class__.__name__ +
                            ": controller to be used not specified")

    def step(self, option):
        """Complete an episode using specified option. This assumes that
        the manager's env is at a place where the option's initiation
        condition is met.

        Args:
            option: index (int) or alias (str) of the high level option
                to be executed

        Returns:
            whatever the controller's step_current_node returns
        """

        # Check whether the incoming option is an integer index or an
        # option alias (string): int() on a non-numeric alias raises
        # ValueError, in which case the value is used as the alias itself.
        try:
            option = int(option)
            option_alias = self.maneuvers_alias[option]
        except ValueError:
            option_alias = option

        self.controller.set_current_node(option_alias)

        # execute whole maneuver
        return self.controller.step_current_node(visualize_low_level_steps=self.visualize_low_level_steps)

    def reset(self):
        """Reset the environment. This function may be needed to reset the
        environment for eg. after an MCTS rollout and update. Also reset the
        controller to root node.

        Returns:
            whatever the environment's reset returns.
        """

        self.controller.set_current_node(self.start_node_alias)
        return self.env.reset()

    def set_controller_policy(self, policy):
        """Sets the trained controller policy as a function which takes in
        feature vector and returns an option index(int). By default,
        trained_policy is None.

        Args:
            policy: a function which takes in feature vector and returns an
                option index(int)
        """
        self.controller.set_trained_policy(policy)

    def set_controller_args(self, **kwargs):
        """Sets custom arguments depending on the chosen controller.

        Args:
            **kwargs: dictionary of args and values
        """
        self.controller.set_controller_args(**kwargs)

    def execute_controller_policy(self):
        """Performs low level steps and transitions to other nodes using
        controller transition policy.

        Returns:
            state after one step, step reward, episode_termination_flag, info
        """
        if self.controller.can_transition():
            self.controller.do_transition()

        return self.controller.low_level_step_current_node()

    def set_current_node(self, node_alias):
        """Sets the current node for controller. Used for training/testing a
        particular node.

        Args:
            node_alias: alias of the node as per config file
        """
        self.controller.set_current_node(node_alias)

    @property
    def current_node(self):
        """Alias of the controller's current node."""
        return self.controller.current_node

    # TODO: specify values using config file, else use these defaults
    # TODO: error handling
    def load_trained_low_level_policies(self):
        """Load pre-trained DDPG weights for every maneuver and install the
        resulting predict function as that maneuver's low level policy.
        """
        for key, maneuver in self.maneuvers.items():
            agent = DDPGLearner(input_shape=(maneuver.get_reduced_feature_length(),),
                                nb_actions=2, gamma=0.99,
                                nb_steps_warmup_critic=200,
                                nb_steps_warmup_actor=200,
                                lr=1e-3)
            agent.load_model("backends/trained_policies/" + key + "/" + key + "_weights.h5f")
            maneuver.set_low_level_trained_policy(agent.predict)

            # NOTE(review): np is assumed to come from the wildcard imports
            # at the top of the file — confirm numpy is re-exported there.
            if self.config["method"] == "mcts":
                maneuver.timeout = np.inf

    def get_number_of_nodes(self):
        """Return the number of maneuver nodes in the graph."""
        return len(self.maneuvers)

    def render(self, mode='human'):
        """Delegate rendering to the underlying environment."""
        self.env.render()