controller_base.py 3.26 KB
Newer Older
Aravind Bk's avatar
Aravind Bk committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
from .policy_base import PolicyBase


class ControllerBase(PolicyBase):
    """Abstract base class for controllers.

    A controller holds a collection of low-level policies ("nodes"), tracks
    which one is currently being executed, and provides helpers to run the
    current node either one low-level step at a time or until the node
    reports termination.

    Subclasses must implement ``can_transition``, ``do_transition`` and
    ``set_current_node``.
    """

    def __init__(self, env, low_level_policies, start_node_alias):
        """Store the environment and policies and select the start node.

        Args:
            env: environment object; this class uses its ``render()``,
                ``get_features_tuple()`` and ``termination_condition``.
            low_level_policies: mapping indexed by node alias whose values
                are the node objects to execute.
            start_node_alias: key into ``low_level_policies`` of the node
                that is current at construction time.
        """
        self.env = env
        self.low_level_policies = low_level_policies

        # TODO: Move an intermediate class so that base class can be clean
        self.current_node = self.low_level_policies[start_node_alias]
        # Set by low_level_step_current_node() from the node's own terminal
        # flag; step_current_node() loops until it becomes True.
        self.node_terminal_state_reached = False
        # Maps attribute name -> default value; consumed by
        # set_controller_args(). Empty here (presumably populated by
        # subclasses -- confirm against the concrete controllers).
        self.controller_args_defaults = {}

    def set_controller_args(self, **kwargs):
        """Set controller attributes from ``kwargs``, falling back to defaults.

        Only the names listed in ``self.controller_args_defaults`` are set;
        any other keyword arguments are ignored.

        Args:
            **kwargs: attribute values overriding the registered defaults.
        """
        for (prop, default) in self.controller_args_defaults.items():
            setattr(self, prop, kwargs.get(prop, default))

    def can_transition(self):
        """Returns boolean signifying whether we can transition.

        To be implemented in subclass.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # NotImplementedError is the exception type; the NotImplemented
        # constant is not callable and would surface as a TypeError.
        raise NotImplementedError(self.__class__.__name__ +
                                  ".can_transition is not implemented.")

    def do_transition(self, observation):
        """Do a transition, assuming we can transition. To be implemented in
        subclass.

        Args:
            observation: final observation from episodic step

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(self.__class__.__name__ +
                                  ".do_transition is not implemented.")

    def set_current_node(self, node_alias):
        """Sets the current node which is being executed.

        Args:
            node_alias: node alias of the node to be set

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(self.__class__.__name__ +
                                  ".set_current_node is not implemented.")

    # TODO: Looks generic. Move to an intermediate class/highlevel manager so that base class can be clean
    # TODO: this is never called when you TEST high-level policy (w/o MCTS) rather than train...
    # (make some integrated interface btw testing and training and b.t.w. the high- and low-level
    # methods with and without MCTS.
    def step_current_node(self, visualize_low_level_steps=False):
        """Executes the current node until node termination condition is
        reached.

        Args:
            visualize_low_level_steps: if true, ``self.env.render()`` is
                called after every low-level step.

        Returns:
            Tuple ``(observation, total_reward, terminal, info)`` where
            ``observation``, ``terminal`` and ``info`` come from the last
            low-level step, and ``total_reward`` is the undiscounted sum of
            the step rewards plus the current node's
            ``high_level_extra_reward``.
        """
        total_reward = 0
        # Reset first so the loop body always runs at least once, which
        # guarantees observation/terminal/info are bound on return.
        self.node_terminal_state_reached = False
        while not self.node_terminal_state_reached:
            observation, reward, terminal, info = \
                self.low_level_step_current_node()
            if visualize_low_level_steps:
                self.env.render()
            # TODO: make the total_reward discounted....
            total_reward += reward

        total_reward += self.current_node.high_level_extra_reward

        # TODO for info
        return observation, total_reward, terminal, info

    # TODO: Looks generic. Move to an intermediate class/highlevel manager so that base class can be clean
    def low_level_step_current_node(self):
        """Executes one step of current node.

        Sets ``node_terminal_state_reached`` if the node termination
        condition has been reached.

        Returns:
            Tuple ``(features, reward, episode_terminal, info)``: the
            environment's feature tuple after the step, the step reward,
            ``self.env.termination_condition`` and the info returned by the
            node's ``step``.
        """
        u_ego = self.current_node.low_level_policy(
            self.current_node.get_reduced_features_tuple())
        # The feature value returned by the node is not used; the features
        # handed back to the caller are read from the environment instead.
        _, R, terminal, info = self.current_node.step(u_ego)
        # Node-level termination drives the loop in step_current_node();
        # the flag returned below is the environment's own condition.
        self.node_terminal_state_reached = terminal
        return (self.env.get_features_tuple(), R,
                self.env.termination_condition, info)