from .policy_base import PolicyBase
import numpy as np


class LearnerBase(PolicyBase):
    """Abstract base for learning-policy backends.

    Concrete backends inherit from this class and override the hooks
    below. Every hook defined here other than the constructor is a
    no-op that returns None.
    """

    def __init__(self, input_shape=(10, ), nb_actions=2, **kwargs):
        """Store the space dimensions and the optional tunable properties.

        Args:
            input_shape: Shape of the observation space, e.g. (10,).
            nb_actions: Number of values in the action space.
            **kwargs: Optional overrides for the tunable properties.
                The recognised keys are "lr" (default 0.001) and
                "gamma" (default 0.99); any other key is ignored.
        """
        self.input_shape = input_shape
        self.nb_actions = nb_actions

        # NOTE(review): "lr" / "gamma" are presumably the learning rate and
        # the discount factor -- confirm against the concrete backends.
        hyperparameter_defaults = {"lr": 0.001, "gamma": 0.99}
        for name, fallback in hyperparameter_defaults.items():
            # A caller-supplied value wins; otherwise use the default.
            value = kwargs[name] if name in kwargs else fallback
            setattr(self, name, value)

    def train(self,
              env,
              nb_steps=50000,
              visualize=False,
              nb_max_episode_steps=200):
        """Train the learning agent on the environment.

        The base implementation does nothing; subclasses provide the
        actual training loop.

        Args:
            env: The environment instance. Should provide step() and
                reset() methods and, optionally, render().
            nb_steps: Total number of steps to train for.
            visualize: If True, visualize the training. Works only if
                render() is present in env.
            nb_max_episode_steps: Maximum number of steps per episode.
        """
        # Intentionally a no-op until overridden by a subclass.
        return None

    def save_model(self, file_name, overwrite=True):
        """Save the weights of the agent; meant to be used after learning.

        The base implementation does nothing.

        Args:
            file_name: File name to use when saving.
            overwrite: If True, overwrite an existing file.
        """
        # Intentionally a no-op until overridden by a subclass.
        return None

    def load_model(self, file_name):
        """Load the weights of an agent.

        The base implementation does nothing.

        Args:
            file_name: File name to use when loading.
        """
        # Intentionally a no-op until overridden by a subclass.
        return None

    def test_model(self,
                   env,
                   nb_episodes=5,
                   visualize=True,
                   nb_max_episode_steps=200):
        """Test the agent on the environment.

        The base implementation does nothing.

        Args:
            env: The environment instance. Should provide step(),
                reset() and, optionally, render().
            nb_episodes: Number of episodes to run.
            visualize: If True, visualize the test. Works only if
                render() is present in env.
            nb_max_episode_steps: Maximum number of steps per episode.
        """
        # Intentionally a no-op until overridden by a subclass.
        return None

    def predict(self, observation):
        """Run a forward pass and return the agent's next action.

        Args:
            observation: The current observation. Its shape should be
                the same as self.input_shape.

        Returns:
            The action chosen by the agent for the given observation.
            The base implementation performs no computation and
            returns None.
        """
        # Intentionally a no-op until overridden by a subclass.
        return None