Commit f428cd41 authored by Ashish Gaurav

run docformatter

parent 72e44d55
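Note: docformatter mechanically normalizes Python docstrings toward PEP 257 (typically invoked as docformatter --in-place --recursive .), which accounts for the bulk of the docstring rewrites in the hunks below.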
@@ -48,10 +48,9 @@ class PPO2Agent(LearnerBase):
return MlpPolicy
def create_agent(self, policy, tensorboard):
"""Creates a PPO agent
"""Creates a PPO agent.
Returns:
stable_baselines PPO2 object
Returns: stable_baselines PPO2 object
"""
if tensorboard:
return PPO2(
......
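For context, a PPO2 agent like the one created above is typically constructed with stable_baselines along these lines; a minimal sketch, assuming env is a Gym-compliant environment and the log path is illustrative:

from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy

# env: any Gym-compliant environment instance (assumed to exist)
model = PPO2(MlpPolicy, env, verbose=1, tensorboard_log="./logs/")
model.learn(total_timesteps=25000)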
@@ -18,16 +18,17 @@ class ControllerBase(PolicyBase):
setattr(self, prop, kwargs.get(prop, default))
def can_transition(self):
"""Returns boolean signifying whether we can transition. To be
implemented in subclass.
"""Returns boolean signifying whether we can transition.
To be implemented in subclass.
"""
raise NotImplementedError(self.__class__.__name__ + \
" can_transition is not implemented.")
def do_transition(self, observation):
"""Do a transition, assuming we can transition. To be
implemented in subclass.
"""Do a transition, assuming we can transition. To be implemented in
subclass.
Args:
observation: final observation from episodic step
@@ -37,7 +38,7 @@ class ControllerBase(PolicyBase):
"do_transition is not implemented.")
def set_current_node(self, node_alias):
"""Sets the current node which is being executed
"""Sets the current node which is being executed.
Args:
node: node alias of the node to be set
......
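A minimal sketch of what a concrete ControllerBase subclass might look like; the class name and the node-order bookkeeping are illustrative, not from this codebase:

class FixedOrderController(ControllerBase):  # hypothetical subclass
    def can_transition(self):
        # transition once the current maneuver reports termination (assumed attributes)
        return self.low_level_policies[self.current_node].termination_condition

    def do_transition(self, observation):
        # advance to the next alias in a fixed node order (node_order is assumed)
        self.node_index += 1
        self.set_current_node(self.node_order[self.node_index])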
@@ -431,8 +431,8 @@ class DQNAgentOverOptions(DQNAgent):
self.low_level_policies.keys())
def __get_invalid_node_indices(self):
"""Returns a list of option indices that are invalid according to initiation conditions.
"""
"""Returns a list of option indices that are invalid according to
initiation conditions."""
invalid_node_indices = list()
for index, option_alias in enumerate(self.low_level_policy_aliases):
self.low_level_policies[option_alias].reset_maneuver()
......
@@ -20,8 +20,8 @@ class ManualPolicy(ControllerBase):
self.adj = transition_adj
def _transition(self):
"""Check if the current node's termination condition is met and if
it is possible to transition to another node, i.e. its initiation
"""Check if the current node's termination condition is met and if it
is possible to transition to another node, i.e. its initiation
condition is met. This is an internal function.
Returns the new node if a transition can happen, None otherwise
......
@@ -4,9 +4,8 @@ import pickle
class MCTSLearner(ControllerBase):
"""Monte Carlo Tree Search implementation using the UCB1 and
progressive widening approach as explained in Paxton et al (2017).
"""
"""Monte Carlo Tree Search implementation using the UCB1 and progressive
widening approach as explained in Paxton et al (2017)."""
_ucb_vals = set()
@@ -76,8 +75,8 @@ class MCTSLearner(ControllerBase):
self.new_node_num = to_restore['new_node_num']
def _create_node(self, low_level_policy):
"""Create the node associated with curr_node_num, using the
given low level policy.
"""Create the node associated with curr_node_num, using the given low
level policy.
Args:
low_level_policy: the option's alias
@@ -92,11 +91,10 @@ class MCTSLearner(ControllerBase):
return created_node_num, {"policy": low_level_policy}
def _to_discrete(self, observation):
"""Converts observation to a discrete observation tuple. Also
append (a) whether we are following a vehicle, and (b) whether
there is a vehicle in the opposite lane in the approximately
the same x position. These values will be useful for Follow
and ChangeLane maneuvers.
"""Converts observation to a discrete observation tuple. Also append
(a) whether we are following a vehicle, and (b) whether there is a
vehicle in the opposite lane in the approximately the same x position.
These values will be useful for Follow and ChangeLane maneuvers.
Args:
observation: observation tuple from the environment
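A sketch of the kind of discretization described above; the bin width, observation layout, and the two helper predicates are assumptions for illustration:

def _to_discrete_sketch(observation, bin_width=2.0):
    x, y, v = observation[0], observation[1], observation[2]  # assumed layout
    discrete = (round(x / bin_width), round(y / bin_width), round(v / bin_width))
    # append the two boolean flags mentioned in the docstring
    return discrete + (is_following_vehicle(observation),      # hypothetical helper
                       vehicle_in_opposite_lane(observation))  # hypothetical helper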
@@ -137,9 +135,9 @@ class MCTSLearner(ControllerBase):
def _get_visitation_count(self, observation, option=None):
"""Finds the visitation count of the discrete form of the observation.
If discrete observation not found, then inserted into self.N with
value 0. Auto converts the observation into discrete form. If option
is not None, then this uses self.M instead of self.N
If the discrete observation is not found, it is inserted into self.N
with value 0. Auto-converts the observation into discrete form. If
option is not None, then this uses self.M instead of self.N.
Args:
observation: observation tuple from the environment
@@ -177,8 +175,9 @@ class MCTSLearner(ControllerBase):
def _get_possible_options(self):
"""Returns a set of options that can be taken from the current node.
Goes through adjacency set of current node and finds which next nodes'
initiation condition is met.
Goes through the adjacency set of the current node and finds which
next nodes' initiation conditions are met.
"""
all_options = set(self.low_level_policies.keys())
@@ -211,9 +210,9 @@ class MCTSLearner(ControllerBase):
return visited_aliases
def _ucb_adjusted_q(self, observation, C=1):
"""Computes Q_star(observation, option_i) plus the UCB term, which
is C*[predictor(observation, option_i)]/[1+N(observation, option_i)],
for all option_i in the adjacency set of the current node.
"""Computes Q_star(observation, option_i) plus the UCB term, which is
C*[predictor(observation, option_i)]/[1+N(observation, option_i)], for
all option_i in the adjacency set of the current node.
Args:
observation: observation tuple from the environment
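The quantity described above can be computed per option roughly as follows; a sketch, assuming q_star and N are dicts keyed by (observation, option):

def ucb_adjusted_q_sketch(q_star, predictor, N, observation, options, C=1):
    return {o: q_star[(observation, o)]
               + C * predictor(observation, o) / (1.0 + N.get((observation, o), 0))
            for o in options}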
@@ -257,14 +256,14 @@ class MCTSLearner(ControllerBase):
return sum_rewards / (1 + self._get_visitation_count(observation))
def _select(self, observation, depth=0, visualize=False):
"""MCTS selection function. For representation, we only use
the discrete part of the observation.
"""MCTS selection function. For representation, we only use the
discrete part of the observation.
Args:
observation: observation tuple from the environment
depth: current depth, starts from root node, hence 0 by default
visualize: whether or not to visualize low level steps
Returns the sum of values from the given observation.
"""
@@ -316,8 +315,8 @@ class MCTSLearner(ControllerBase):
return self._value(observation), all_ep_R + episode_R
def traverse(self, observation, visualize=False):
"""Do a complete traversal from root to leaf. Assumes the
environment is reset and we are at the root node.
"""Do a complete traversal from root to leaf. Assumes the environment
is reset and we are at the root node.
Args:
observation: observation from the environment
@@ -329,13 +328,13 @@ class MCTSLearner(ControllerBase):
return self._select(observation, visualize=visualize)
def do_transition(self, observation, visualize=False):
"""Do a transition using UCB metric, with the latest observation
from the episodic step.
"""Do a transition using UCB metric, with the latest observation from
the episodic step.
Args:
observation: final observation from episodic step
visualize: whether or not to visualize low level steps
visualize: whether or not to visualize low level steps
Returns o_star using UCB metric
"""
......
@@ -5,7 +5,7 @@ import numpy as np
class OnlineMCTSController(ControllerBase):
"""Online MCTS"""
"""Online MCTS."""
def __init__(self, env, low_level_policies, start_node_alias):
"""Constructor for manual policy execution.
......
class PolicyBase:
"""Abstract policy base from which every policy backend is defined
and inherited."""
"""Abstract policy base from which every policy backend is defined and
inherited."""
@@ -3,23 +3,20 @@ from model_checker import Parser
class GymCompliantEnvBase:
def step(self, action):
""" Gym compliant step function which
will be implemented in the subclass.
"""
"""Gym compliant step function which will be implemented in the
subclass."""
raise NotImplementedError(self.__class__.__name__ +
" step is not implemented.")
def reset(self):
""" Gym compliant reset function which
will be implemented in the subclass.
"""
"""Gym compliant reset function which will be implemented in the
subclass."""
raise NotImplementedError(self.__class__.__name__ +
" reset is not implemented.")
def render(self):
""" Gym compliant step function which
will be implemented in the subclass.
"""
"""Gym compliant step function which will be implemented in the
subclass."""
raise NotImplemented(self.__class__.__name__ +
"render is not implemented.")
@@ -49,7 +46,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
def _init_LTL_preconditions(self):
"""Initialize the LTL preconditions (self._LTL_preconditions)..
in the subclass.
in the subclass.
"""
return
@@ -72,9 +70,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
self._r_terminal = None
def _terminal_reward_superposition(self, r_obs):
"""Calculate the next value when observing "obs," from the prior
using min, max, or summation depending on the superposition_type.
"""
"""Calculate the next value when observing "obs," from the prior using
min, max, or summation depending on the superposition_type."""
if r_obs is None:
return
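A sketch of the superposition rule the docstring describes, assuming a prior value and a superposition_type string; this is illustrative, not the actual method body:

def superpose_sketch(prior, r_obs, superposition_type):
    if r_obs is None:
        return prior
    if prior is None:
        return r_obs
    if superposition_type == 'min':
        return min(prior, r_obs)
    elif superposition_type == 'max':
        return max(prior, r_obs)
    return prior + r_obs  # summation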
@@ -133,7 +130,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
return violate, info
def current_model_checking_result(self):
"""Returns whether or not any of the conditions is currently violated."""
"""Returns whether or not any of the conditions is currently
violated."""
for LTL_precondition in self._LTL_preconditions:
if LTL_precondition.result == Parser.FALSE:
@@ -155,9 +153,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
@property
def termination_condition(self):
"""In the subclass, specify the condition for termination of the episode
(or the maneuver).
"""
"""In the subclass, specify the condition for termination of the
episode (or the maneuver)."""
if self._terminate_in_goal and self.goal_achieved:
return True
......
class RoadEnv:
"""The generic road env
"""The generic road env.
TODO: Implement this generic road env for plugging-in other road envs.
TODO: roadEnv also having a step() function can cause a problem.
TODO: Implement this generic road env for plugging-in other road envs.
TODO: roadEnv also having a step() function can cause a problem.
"""
\ No newline at end of file
@@ -165,9 +165,7 @@ class Features(object):
return self.other_vehs[index]
def get_features_tuple(self):
"""
continuous + discrete features
"""
"""continuous + discrete features."""
feature = self.con_ego + self.dis_ego
for other_veh in self.other_vehs:
......
@@ -14,7 +14,7 @@ class Rectangle(Shape):
"""Rectangle using OpenGL."""
def __init__(self, xmin, xmax, ymin, ymax, color=(0, 0, 0, 255)):
""" Constructor for Rectangle.
"""Constructor for Rectangle.
The rectangle has the four points (xmin, ymin), (xmin, ymax),
(xmax, ymax), (xmax, ymin).
......
@@ -150,7 +150,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
return self.vehs[EGO_INDEX]
def generate_scenario(self, **kwargs):
"""Randomly generate a road scenario with
"""Randomly generate a road scenario with.
"the N-number of vehicles + an ego vehicle"
@@ -609,9 +609,9 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
@staticmethod
def __regulate_veh_speed(veh, dist, v_max_multiplier):
"""This private method set and regulate the speed of the
vehicle in a way that it never requires to travel the given
distance for complete stop.
"""This private method set and regulate the speed of the vehicle in a
way that it never requires to travel the given distance for complete
stop.
Args:
veh: the vehicle reference.
@@ -920,9 +920,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
self.ego.theta) > self.max_ego_theta
def _check_collisions(self):
"""returns True when collision happens, False if not.
"""
"""returns True when collision happens, False if not."""
self.__ego_collision_happened = self.check_ego_collision() or \
self.is_ego_off_road()
self.__others_collision_happened = self.check_other_veh_collisions()
@@ -1004,8 +1002,8 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
return True if (h_offroad and v_offroad) else False
def cost(self, u):
"""Calculate the driving cost of the ego, i.e.,
the negative reward for the ego-driving.
"""Calculate the driving cost of the ego, i.e., the negative reward for
the ego-driving.
Args:
u: the low-level input (a, dot_psi) to the ego vehicle.
@@ -1045,8 +1043,9 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
# TODO: Move this to utilities
def normalize_tuple(self, vec, scale_factor=10):
"""Normalizes each element in a tuple according to ranges defined in self.cost_normalization_ranges.
Normalizes between 0 and 1. And the scales by scale_factor
"""Normalizes each element in a tuple according to ranges defined in
self.cost_normalization_ranges. Normalizes between 0 and 1. And the
scales by scale_factor.
Args:
vec: The tuple to be normalized
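A sketch of the normalization described above; the (low, high) range pairs stand in for self.cost_normalization_ranges and are assumptions:

def normalize_tuple_sketch(vec, ranges, scale_factor=10):
    # ranges: one (low, high) pair per element of vec
    normalized = [(value - low) / (high - low)
                  for value, (low, high) in zip(vec, ranges)]
    return tuple(scale_factor * v for v in normalized)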
@@ -1070,8 +1069,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
return tuple(normalized_vec)
def get_features_tuple(self):
"""
Get/calculate the features wrt. the current state variables
"""Get/calculate the features wrt. the current state variables.
Returns features tuple
"""
@@ -1173,11 +1171,12 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
@property
def goal_achieved(self):
"""
A property from the base class which is True if the goal
of the road scenario is achieved, otherwise False. This property is
used in both step of EpisodicEnvBase and the implementation of
the high-level reinforcement learning and execution.
"""A property from the base class which is True if the goal of the road
scenario is achieved, otherwise False.
This property is used both in the step function of EpisodicEnvBase
and in the implementation of the high-level reinforcement learning
and execution.
"""
return (self.ego.x >= rd.hlanes.end_pos) and \
@@ -1220,9 +1219,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
label.draw()
def render(self):
"""
Gym compliant render function.
"""
"""Gym compliant render function."""
if self.window is None:
self.window = config_Pyglet(pyglet, NUM_HPIXELS, NUM_VPIXELS)
......
@@ -247,7 +247,8 @@ def draw_all_shapes(shapes):
def calculate_v_max(dist):
"""Calculate the maximum velocity you can reach at the given position ahead.
"""Calculate the maximum velocity you can reach at the given position
ahead.
Args:
dist: the distance you travel from the current position.
@@ -259,8 +260,8 @@ def calculate_v_max(dist):
def calculate_s(v):
"""Calculate the distance traveling when the vehicle is maximally
de-accelerating for a complete stop.
"""Calculate the distance traveling when the vehicle is maximally de-
accelerating for a complete stop.
Args:
v: the current speed of the vehicle.
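Under a constant maximal deceleration a_max, these two functions are inverses of each other: v_max(d) = sqrt(2 * a_max * d) and s(v) = v**2 / (2 * a_max). A sketch, with an assumed value for the constant:

import math

A_MAX = 5.0  # assumed maximal deceleration in m/s^2; the repository defines its own constant

def calculate_v_max_sketch(dist):
    return math.sqrt(2.0 * A_MAX * dist)

def calculate_s_sketch(v):
    return v * v / (2.0 * A_MAX)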
......
@@ -16,8 +16,8 @@ def high_level_policy_training(nb_steps=25000,
visualize=False,
tensorboard=False,
save_path="highlevel_weights.h5f"):
"""
Do RL of the high-level policy and test it.
"""Do RL of the high-level policy and test it.
Args:
nb_steps: the number of steps to perform RL
load_weights: True if the pre-learned NN weights are loaded (for initialization of the NNs)
......
@@ -16,8 +16,8 @@ def low_level_policy_training(maneuver,
visualize=False,
nb_episodes_for_test=10,
tensorboard=False):
"""
Do RL of the low-level policy of the given maneuver and test it.
"""Do RL of the low-level policy of the given maneuver and test it.
Args:
maneuver: the name of the maneuver defined in config.json (e.g., 'default').
nb_steps: the number of steps to perform RL.
......
@@ -34,8 +34,8 @@ def mcts_training(nb_traversals,
visualize=False,
load_saved=False,
save_file="mcts.pickle"):
"""
Do RL of the low-level policy of the given maneuver and test it.
"""Do RL of the low-level policy of the given maneuver and test it.
Args:
nb_traversals: number of MCTS traversals
save_every: save at every these many traversals
@@ -91,8 +91,8 @@ def mcts_evaluation(nb_traversals,
visualize=False,
save_file="mcts.pickle",
pretrained=False):
"""
Do RL of the low-level policy of the given maneuver and test it.
"""Do RL of the low-level policy of the given maneuver and test it.
Args:
nb_traversals: number of MCTS traversals
save_every: save at every these many traversals
@@ -170,8 +170,8 @@ def online_mcts(nb_episodes=10):
first_time = False
else:
print('Stepping through ...')
features, R, terminal, info = options.controller.step_current_node(
visualize_low_level_steps=True)
features, R, terminal, info = options.controller.\
step_current_node(visualize_low_level_steps=True)
episode_reward += R
print('Intermediate Reward: %f (ego x = %f)' %
(R, options.env.vehs[0].x))
@@ -224,8 +224,8 @@ def evaluate_online_mcts(nb_episodes=20, nb_trials=5):
first_time = False
else:
print('Stepping through ...')
features, R, terminal, info = options.controller.step_current_node(
visualize_low_level_steps=True)
features, R, terminal, info = options.controller.\
step_current_node(visualize_low_level_steps=True)
episode_reward += R
print('Intermediate Reward: %f (ego x = %f)' %
(R, options.env.vehs[0].x))
......
@@ -5,17 +5,17 @@ from model_checker.parser import Parser, Errors
class LTLPropertyBase(object):
"""This is a base class that contains information of an LTL property.
It encapsulates the model-checking part (see check / check_incremental),
and contains additional information. The subclass needs to describe
specific APdict to be used.
It encapsulates the model-checking part (see check /
check_incremental), and contains additional information. The
subclass needs to describe specific APdict to be used.
"""
#: The atomic propositions dict you must set in the subclass.
APdict = None
def __init__(self, LTL_str, penalty, enabled=True):
"""Constructor for LTLPropertyBase.
Assumes property does not change, but property may be applied to multiple traces.
"""Constructor for LTLPropertyBase. Assumes property does not change,
but property may be applied to multiple traces.
Args:
LTL_str: the human readable string representation of the LTL property
@@ -42,8 +42,10 @@ class LTLPropertyBase(object):
self.result = Parser.UNDECIDED
def reset_property(self):
"""Resets existing property so that it can be applied to a new sequence of states.
Assumes init_property or check were previously called.
"""Resets existing property so that it can be applied to a new sequence
of states.
Assumes init_property or check were previously called.
"""
self.parser.ResetProperty()
self.result = Parser.UNDECIDED
@@ -64,11 +66,11 @@ class LTLPropertyBase(object):
def check_incremental(self, state):
"""Checks an initialised property w.r.t. the next state in a trace.
Assumes init_property or check were previously called.
Assumes init_property or check were previously called.
Args:
state: next state (an integer)
Returns:
incremental result, in {TRUE, FALSE, UNDECIDED}
"""
......
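Hypothetical usage of the incremental interface described above; the subclass name, property string, and state encoding are illustrative only:

prop = MyLTLProperty("G safe", penalty=-100)  # MyLTLProperty: a subclass that sets APdict
prop.reset_property()
for state in trace:  # trace: iterable of integer-encoded atomic-proposition sets
    result = prop.check_incremental(state)
# result is one of Parser.TRUE, Parser.FALSE, Parser.UNDECIDED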
@@ -2,9 +2,9 @@ class Bits(object):
"""A bit-control class that allows us bit-wise manipulation as shown in the
example::
bits = Bits()
bits[0] = False
bits[2] = bits[0]
bits = Bits()
bits[0] = False
bits[2] = bits[0]
"""
def __init__(self, value=0):
@@ -35,8 +35,9 @@ class Bits(object):
class AtomicPropositionsBase(Bits):
"""An AP-control base class for AP-wise manipulation.
the dictionary APdict and its length APdict_len has
to be given in the subclass
the dictionary APdict and its length APdict_len have to be given in
the subclass.
"""
APdict = None
......
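A sketch consistent with the Bits docstring above (the class's actual implementation may differ):

class BitsSketch(object):
    def __init__(self, value=0):
        self.value = value

    def __getitem__(self, index):
        return (self.value >> index) & 1 == 1

    def __setitem__(self, index, bit):
        if bit:
            self.value |= (1 << index)
        else:
            self.value &= ~(1 << index)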
@@ -191,9 +191,11 @@ class Parser(object):
TRUE = 2
def Check_old(self, propscanner, trace):
"""Deprecated method to check an entire trace with a new property Scanner.
Includes SetProperty, which includes ResetProperty.
"""
"""Deprecated method to check an entire trace with a new property
Scanner.
Includes SetProperty, which includes ResetProperty.
"""
self.SetProperty(propscanner)
for state in trace:
result = self.CheckIncremental(state)
@@ -201,9 +203,11 @@ class Parser(object):
return result
def Check(self, trace):
"""Checks an entire trace w.r.t. an existing property Scanner.
Includes ResetProperty, but not SetProperty.
"""
"""Checks an entire trace w.r.t.
an existing property Scanner. Includes ResetProperty, but not
SetProperty.
"""
self.ResetProperty()
for state in trace:
result = self.CheckIncremental(state)
@@ -211,23 +215,23 @@ class Parser(object):
return result
def SetProperty(self, propscanner):
"""Sets the property Scanner that tokenizes the property.
"""
"""Sets the property Scanner that tokenizes the property."""
self.scanner = propscanner
self.ResetProperty()
def ResetProperty(self):
"""Re-iniitializes an existing property Scanner.
"""
"""Re-iniitializes an existing property Scanner."""
self.step = 0
self.maxfactor = -1
self.start = self.scanner.t = self.scanner.tokens
def CheckIncremental(self, state):
"""Checks a new state w.r.t. an existing property Scanner.
Constructs a trace from new states using a new or previous trace list.
If a previous trace list is used, states after index self.step are not valid.
"""
"""Checks a new state w.r.t.
an existing property Scanner. Constructs a trace from new states
using a new or previous trace list. If a previous trace list is
used, states after index self.step are not valid.
"""
if len(self.trace) <= self.step:
self.trace.append(state)
else:
......
@@ -4,8 +4,7 @@ from .AP_dict import AP_dict_simple_intersection
# TODO: classes.py is not a good name. Split the two classes into separate files and rename them.
class AtomicPropositions(AtomicPropositionsBase):
""" An AP-control class for AP-wise manipulation as shown in the
example:
"""An AP-control class for AP-wise manipulation as shown in the example:
APs = AtomicPropositions()
APs[0] = False # this is same as
@@ -15,16 +14,15 @@ class AtomicPropositions(AtomicPropositionsBase):
Requires:
The index in [...] is an integer that should be in the range
{0, 1, 2, ..., AP_dict_len}.
"""
APdict = AP_dict_simple_intersection
class LTLProperty(LTLPropertyBase):
""" is a class that contains information of an LTL property in
simple_intersection road scenario.
"""is a class that contains information of an LTL property in
simple_intersection road scenario.
It encapsulates the model-checking part and contains additional
information.
It encapsulates the model-checking part and contains additional
information.
"""
APdict = AP_dict_simple_intersection
@@ -5,8 +5,7 @@ from backends import RLController, DDPGLearner, MCTSLearner, OnlineMCTSControlle
class OptionsGraph:
"""
Represent the options graph as a graph like structure. The configuration
"""Represent the options graph as a graph like structure. The configuration
is specified in a json file and consists of the following specific values:
* nodes: dictionary of policy node aliases -> maneuver classes
@@ -18,8 +17,7 @@ class OptionsGraph:
visualize_low_level_steps = False
def __init__(self, json_path, env_class, *env_args, **env_kwargs):
"""
Constructor for the options graph.
"""Constructor for the options graph.
Args:
json_path: path to the json file which stores the configuration
@@ -78,9 +76,9 @@ class OptionsGraph:
"Controller to be used not specified")
def step(self, option):
"""Complete an episode using specified option. This assumes that
the manager's env is at a place where the option's initiation
condition is met.
"""Complete an episode using specified option. This assumes that the
manager's env is at a place where the option's initiation condition is
met.
Args:
option: index of high level option to be executed
@@ -111,8 +109,9 @@ class OptionsGraph:
return self.env.reset()
def set_controller_policy(self, policy):
"""Sets the trained controller policy as a function which takes in feature vector and returns
an option index(int). By default, trained_policy is None
"""Sets the trained controller policy as a function which takes in
feature vector and returns an option index(int). By default,
trained_policy is None.
Args:
policy: a function which takes in a feature vector and returns an option index (int)
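Hypothetical usage: any callable mapping a feature vector to an option index works, e.g. options.set_controller_policy(lambda features: 0) to always pick option 0.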
@@ -120,7 +119,7 @@ class OptionsGraph:
self.controller.set_trained_policy(policy)
def set_controller_args(self, **kwargs):
"""Sets custom arguments depending on the chosen controller
"""Sets custom arguments depending on the chosen controller.
Args:
**kwargs: dictionary of args and values
@@ -128,9 +127,11 @@ class OptionsGraph:
self.controller.set_controller_args(**kwargs)
def execute_controller_policy(self):
"""Performs low level steps and transitions to other nodes using controller transition policy.
"""Performs low level steps and transitions to other nodes using
controller transition policy.
Returns state after one step, step reward, episode_termination_flag, info
Returns state after one step, step reward,
episode_termination_flag, info
"""
if self.controller.can_transition():
self.controller.do_transition()
@@ -138,7 +139,8 @@ class OptionsGraph:
return self.controller.low_level_step_current_node()
def set_current_node(self, node_alias):
"""Sets the current node for controller. Used for training/testing a particular node
"""Sets the current node for controller. Used for training/testing a
particular node.
Args:
node_alias: alias of the node as per config file
......
@@ -172,8 +172,9 @@ class ManeuverBase(EpisodicEnvBase):
return self._features_dim_reduction(features), reward, terminal, info
def get_reduced_feature_length(self):
"""
Get the length of the feature tuple after applying _features_dim_reduction.