# controller_base.py
from .policy_base import PolicyBase


class ControllerBase(PolicyBase):
    """Abstract class for controllers.

    A controller keeps a collection of low-level policies ("nodes"), tracks
    which node is currently being executed, and steps that node until it
    reports termination. Transition logic is supplied by subclasses.
    """

    def __init__(self, env, low_level_policies, start_node_alias):
        """Store the environment and policies and select the start node.

        Args:
            env: environment object; this class calls its ``render()`` and
                ``get_features_tuple()`` methods and reads its
                ``termination_condition`` attribute.
            low_level_policies: container of node objects, indexable by
                node alias.
            start_node_alias: alias of the node to start from.
        """
        self.env = env
        self.low_level_policies = low_level_policies

        # TODO: Move an intermediate class so that base class can be clean
        self.current_node = self.low_level_policies[start_node_alias]
        self.node_terminal_state_reached = False
        self.controller_args_defaults = {}

    def set_controller_args(self, **kwargs):
        """Set controller attributes from keyword arguments.

        Only the names listed in ``controller_args_defaults`` are set. Each
        one takes the value given in ``kwargs`` or, if absent, its default.
        Keyword arguments not listed in the defaults are ignored.
        """
        for prop, default in self.controller_args_defaults.items():
            setattr(self, prop, kwargs.get(prop, default))

    def can_transition(self):
        """Returns boolean signifying whether we can transition.

        To be implemented in subclass.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # NotImplementedError is the exception; NotImplemented is a
        # non-callable constant meant for binary-operator methods.
        raise NotImplementedError(
            self.__class__.__name__ + ".can_transition is not implemented.")

    def do_transition(self, observation):
        """Do a transition, assuming we can transition. To be implemented in
        subclass.

        Args:
            observation: final observation from episodic step

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            self.__class__.__name__ + ".do_transition is not implemented.")

    def set_current_node(self, node_alias):
        """Sets the current node which is being executed.

        Args:
            node_alias: node alias of the node to be set

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            self.__class__.__name__ + ".set_current_node is not implemented.")

    # TODO: Looks generic. Move to an intermediate class/highlevel manager so that base class can be clean
    # TODO: this is never called when you test high-level policy rather than train...
    def step_current_node(self, visualize_low_level_steps=False):
        """Executes the current node until node termination condition is
        reached.

        Args:
            visualize_low_level_steps: if true, render the environment after
                every low-level step.

        Returns:
            Tuple of (state at end of node execution, total reward,
            episode_termination_flag, info from the last low-level step).
        """
        total_reward = 0
        # Reset the flag so the loop below always runs at least once, which
        # guarantees ``observation`` and ``info`` are bound on return.
        self.node_terminal_state_reached = False
        while not self.node_terminal_state_reached:
            observation, reward, _, info = self.low_level_step_current_node()
            if visualize_low_level_steps:
                self.env.render()
            # TODO: make the total_reward discounted....
            total_reward += reward

        total_reward += self.current_node.high_level_extra_reward

        # TODO for info
        return observation, total_reward, self.env.termination_condition, info

    # TODO: Looks generic. Move to an intermediate class/highlevel manager so that base class can be clean
    def low_level_step_current_node(self):
        """Executes one step of current node.

        Sets the ``node_terminal_state_reached`` flag from the terminal value
        returned by the node's ``step``.

        Returns:
            Tuple of (state after one step, step reward,
            episode_termination_flag, info).
        """
        u_ego = self.current_node.low_level_policy(
            self.current_node.get_reduced_features_tuple())
        # The feature value returned by the node is not used; the returned
        # state is taken from the environment instead.
        _, step_reward, terminal, info = self.current_node.step(u_ego)
        self.node_terminal_state_reached = terminal
        return (self.env.get_features_tuple(), step_reward,
                self.env.termination_condition, info)