controller_base.py 3 KB
Newer Older
Aravind Bk's avatar
Aravind Bk committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
from .policy_base import PolicyBase


class ControllerBase(PolicyBase):
    """Abstract class for controllers.

    A controller holds an environment and a mapping of low-level policies
    (nodes) keyed by alias. It tracks which node is currently executing and
    whether that node has reached its terminal state. Transition logic
    (``can_transition``, ``do_transition``, ``set_current_node``) must be
    implemented by subclasses.
    """

    def __init__(self, env, low_level_policies, start_node_alias):
        """Initialize the controller.

        Args:
            env: environment the controller acts in. Methods below use its
                ``render()``, ``get_features_tuple()`` and
                ``termination_condition``.
            low_level_policies: mapping from node alias to node object.
            start_node_alias: key into ``low_level_policies`` selecting the
                node that is executed first.
        """
        self.env = env
        self.low_level_policies = low_level_policies

        # TODO: Move to an intermediate class so that base class can be clean
        self.current_node = self.low_level_policies[start_node_alias]
        self.node_terminal_state_reached = False
        # Maps controller-argument name -> default value, consumed by
        # set_controller_args. Empty here; presumably filled in by subclasses.
        self.controller_args_defaults = {}

    def set_controller_args(self, **kwargs):
        """Set each argument named in ``controller_args_defaults`` as an
        attribute on ``self``.

        The value is taken from ``kwargs`` when present, otherwise the
        registered default is used. Keyword arguments not listed in
        ``controller_args_defaults`` are ignored.
        """
        for (prop, default) in self.controller_args_defaults.items():
            setattr(self, prop, kwargs.get(prop, default))

    def can_transition(self):
        """Return a boolean signifying whether we can transition.

        To be implemented in subclass.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            self.__class__.__name__ + ": can_transition is not implemented.")

    def do_transition(self, observation):
        """Do a transition, assuming we can transition.

        To be implemented in subclass.

        Args:
            observation: final observation from episodic step

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            self.__class__.__name__ + ": do_transition is not implemented.")

    def set_current_node(self, node_alias):
        """Set the current node which is being executed.

        To be implemented in subclass.

        Args:
            node_alias: alias of the node to be set

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            self.__class__.__name__ + ": set_current_node is not implemented.")

    # TODO: Looks generic. Move to an intermediate class/highlevel manager so that base class can be clean
    def step_current_node(self, visualize_low_level_steps=False):
        """Execute the current node until its termination condition is reached.

        Repeatedly calls ``low_level_step_current_node`` (which updates
        ``node_terminal_state_reached``) and accumulates the step rewards.
        The current node's ``high_level_extra_reward`` is added once at the
        end.

        Args:
            visualize_low_level_steps: if True, call ``self.env.render()``
                after every low-level step.

        Returns:
            Tuple ``(observation, total_reward, episode_termination_flag,
            info)``. ``observation`` and ``info`` come from the last
            low-level step; ``episode_termination_flag`` is
            ``self.env.termination_condition``.
        """
        total_reward = 0
        self.node_terminal_state_reached = False
        # The flag was just reset, so the body runs at least once and
        # observation/info are always bound after the loop.
        while not self.node_terminal_state_reached:
            observation, reward, _, info = self.low_level_step_current_node()
            if visualize_low_level_steps:
                self.env.render()
            total_reward += reward

        total_reward += self.current_node.high_level_extra_reward

        # TODO for info
        return observation, total_reward, self.env.termination_condition, info

    # TODO: Looks generic. Move to an intermediate class/highlevel manager so that base class can be clean
    def low_level_step_current_node(self):
        """Execute one step of the current node.

        Queries the current node's low-level policy on its reduced features
        and steps the node with the resulting action. Sets
        ``node_terminal_state_reached`` to the terminal flag returned by the
        node's ``step``.

        Returns:
            Tuple ``(features, reward, episode_termination_flag, info)``.
            ``features`` is ``self.env.get_features_tuple()`` after the step;
            ``episode_termination_flag`` is
            ``self.env.termination_condition`` (not the node's terminal flag).
        """
        u_ego = self.current_node.low_level_policy(
            self.current_node.get_reduced_features_tuple())
        _, R, terminal, info = self.current_node.step(u_ego)
        self.node_terminal_state_reached = terminal
        return self.env.get_features_tuple(), R, self.env.termination_condition, info