# learner_base.py -- defines LearnerBase, the base class for learning policy backends.
from .policy_base import PolicyBase
# NOTE(review): `np` is not referenced in the visible part of this module -- confirm before removing.
import numpy as np


class LearnerBase(PolicyBase):
    """Abstract base class from which each learning policy backend is defined
    and inherited.

    Apart from the constructor, every method here is a placeholder that does
    nothing and returns None; a backend overrides the ones it supports.
    """

    def __init__(self, input_shape=(10, ), nb_actions=2, **kwargs):
        """Set the properties shared by all learner backends.

        Args:
            input_shape: Shape of observation space, e.g. (10,).
            nb_actions: Number of values in action space.
            **kwargs: Optional overrides for the properties listed in
                property_defaults below ("lr", default 0.001, and "gamma",
                default 0.99). Keys that are not listed there are ignored
                by this constructor.
        """
        # NOTE(review): PolicyBase.__init__ is not invoked here -- confirm
        # that the base class needs no initialisation of its own.
        self.input_shape = input_shape
        self.nb_actions = nb_actions

        # Presumably the learning rate and the discount factor -- confirm
        # against the concrete backends that consume them.
        property_defaults = {"lr": 0.001, "gamma": 0.99}

        # Each property becomes an instance attribute, taking the caller's
        # value when one was passed and the default otherwise.
        for prop, default in property_defaults.items():
            setattr(self, prop, kwargs.get(prop, default))

    def train(self,
              env,
              nb_steps=50000,
              visualize=False,
              nb_max_episode_steps=200):
        """Train the learning agent on the environment.

        Args:
            env: The environment instance. Should contain step() and reset()
                methods and, optionally, render().
            nb_steps: The total number of steps to train.
            visualize: If True, visualizes the training. Works only if
                render() is present in env.
            nb_max_episode_steps: Maximum number of steps per episode.

        Returns:
            None. The base implementation is a no-op.
        """
        return  # do nothing unless specified in the subclass

    def save_model(self, file_name, overwrite=True):
        """Save the weights of the agent. To be used after learning.

        Args:
            file_name: Filename to be used when saving.
            overwrite: If True, overwrites existing file.

        Returns:
            None. The base implementation is a no-op.
        """
        return  # do nothing unless specified in the subclass

    def load_model(self, file_name):
        """Load the weights of an agent.

        Args:
            file_name: Filename to be used when loading.

        Returns:
            None. The base implementation is a no-op.
        """
        return  # do nothing unless specified in the subclass

    def test_model(self,
                   env,
                   nb_episodes=5,
                   visualize=True,
                   nb_max_episode_steps=200):
        """Test the agent on the environment.

        Args:
            env: The environment instance. Should contain step(), reset()
                and, optionally, render().
            nb_episodes: Number of episodes to run.
            visualize: If True, visualizes the test. Works only if render()
                is present in env.
            nb_max_episode_steps: Maximum number of steps per episode.

        Returns:
            None. The base implementation is a no-op.
        """
        return  # do nothing unless specified in the subclass

    def predict(self, observation):
        """Perform a forward pass and return next action by agent based on
        current observation.

        Args:
            observation: The current observation. Shape should be same as
                self.input_shape.

        Returns:
            The action taken by the agent for the given observation. The
            base implementation is a no-op and returns None.
        """
        return  # do nothing unless specified in the subclass