From f428cd41e145558caad124da3d9023e8d3982d55 Mon Sep 17 00:00:00 2001 From: Ashish Gaurav <agaurav77@yahoo.com> Date: Mon, 19 Nov 2018 18:18:47 -0500 Subject: [PATCH] run docformatter --- backends/baselines_learner.py | 5 +- backends/controller_base.py | 11 ++-- backends/kerasrl_learner.py | 4 +- backends/manual_policy.py | 4 +- backends/mcts_learner.py | 53 +++++++++---------- backends/online_mcts_controller.py | 2 +- backends/policy_base.py | 4 +- env/env_base.py | 31 +++++------ env/road_env.py | 6 +-- env/simple_intersection/features.py | 4 +- env/simple_intersection/shapes.py | 2 +- .../simple_intersection_env.py | 39 +++++++------- env/simple_intersection/utilities.py | 7 +-- high_level_policy_main.py | 4 +- low_level_policy_main.py | 4 +- mcts.py | 16 +++--- model_checker/LTL_property_base.py | 22 ++++---- model_checker/atomic_propositions_base.py | 11 ++-- model_checker/parser.py | 32 ++++++----- model_checker/simple_intersection/classes.py | 12 ++--- options/options_loader.py | 28 +++++----- options/simple_intersection/maneuver_base.py | 40 ++++++++------ ppo2_training.py | 3 +- 23 files changed, 175 insertions(+), 169 deletions(-) diff --git a/backends/baselines_learner.py b/backends/baselines_learner.py index 8a62830..ab4e4da 100644 --- a/backends/baselines_learner.py +++ b/backends/baselines_learner.py @@ -48,10 +48,9 @@ class PPO2Agent(LearnerBase): return MlpPolicy def create_agent(self, policy, tensorboard): - """Creates a PPO agent + """Creates a PPO agent. - Returns: - stable_baselines PPO2 object + Returns: stable_baselines PPO2 object """ if tensorboard: return PPO2( diff --git a/backends/controller_base.py b/backends/controller_base.py index b9bbc35..885f941 100644 --- a/backends/controller_base.py +++ b/backends/controller_base.py @@ -18,16 +18,17 @@ class ControllerBase(PolicyBase): setattr(self, prop, kwargs.get(prop, default)) def can_transition(self): - """Returns boolean signifying whether we can transition. To be - implemented in subclass. + """Returns boolean signifying whether we can transition. + + To be implemented in subclass. """ raise NotImplemented(self.__class__.__name__ + \ "can_transition is not implemented.") def do_transition(self, observation): - """Do a transition, assuming we can transition. To be - implemented in subclass. + """Do a transition, assuming we can transition. To be implemented in + subclass. Args: observation: final observation from episodic step @@ -37,7 +38,7 @@ class ControllerBase(PolicyBase): "do_transition is not implemented.") def set_current_node(self, node_alias): - """Sets the current node which is being executed + """Sets the current node which is being executed. Args: node: node alias of the node to be set diff --git a/backends/kerasrl_learner.py b/backends/kerasrl_learner.py index d8138bb..a98dc92 100644 --- a/backends/kerasrl_learner.py +++ b/backends/kerasrl_learner.py @@ -431,8 +431,8 @@ class DQNAgentOverOptions(DQNAgent): self.low_level_policies.keys()) def __get_invalid_node_indices(self): - """Returns a list of option indices that are invalid according to initiation conditions. - """ + """Returns a list of option indices that are invalid according to + initiation conditions.""" invalid_node_indices = list() for index, option_alias in enumerate(self.low_level_policy_aliases): self.low_level_policies[option_alias].reset_maneuver() diff --git a/backends/manual_policy.py b/backends/manual_policy.py index 4b891e7..6e59db1 100644 --- a/backends/manual_policy.py +++ b/backends/manual_policy.py @@ -20,8 +20,8 @@ class ManualPolicy(ControllerBase): self.adj = transition_adj def _transition(self): - """Check if the current node's termination condition is met and if - it is possible to transition to another node, i.e. its initiation + """Check if the current node's termination condition is met and if it + is possible to transition to another node, i.e. its initiation condition is met. This is an internal function. Returns the new node if a transition can happen, None otherwise diff --git a/backends/mcts_learner.py b/backends/mcts_learner.py index 3743dba..f9bd759 100644 --- a/backends/mcts_learner.py +++ b/backends/mcts_learner.py @@ -4,9 +4,8 @@ import pickle class MCTSLearner(ControllerBase): - """Monte Carlo Tree Search implementation using the UCB1 and - progressive widening approach as explained in Paxton et al (2017). - """ + """Monte Carlo Tree Search implementation using the UCB1 and progressive + widening approach as explained in Paxton et al (2017).""" _ucb_vals = set() @@ -76,8 +75,8 @@ class MCTSLearner(ControllerBase): self.new_node_num = to_restore['new_node_num'] def _create_node(self, low_level_policy): - """Create the node associated with curr_node_num, using the - given low level policy. + """Create the node associated with curr_node_num, using the given low + level policy. Args: low_level_policy: the option's alias @@ -92,11 +91,10 @@ class MCTSLearner(ControllerBase): return created_node_num, {"policy": low_level_policy} def _to_discrete(self, observation): - """Converts observation to a discrete observation tuple. Also - append (a) whether we are following a vehicle, and (b) whether - there is a vehicle in the opposite lane in the approximately - the same x position. These values will be useful for Follow - and ChangeLane maneuvers. + """Converts observation to a discrete observation tuple. Also append + (a) whether we are following a vehicle, and (b) whether there is a + vehicle in the opposite lane in the approximately the same x position. + These values will be useful for Follow and ChangeLane maneuvers. Args: observation: observation tuple from the environment @@ -137,9 +135,9 @@ class MCTSLearner(ControllerBase): def _get_visitation_count(self, observation, option=None): """Finds the visitation count of the discrete form of the observation. - If discrete observation not found, then inserted into self.N with - value 0. Auto converts the observation into discrete form. If option - is not None, then this uses self.M instead of self.N + If discrete observation not found, then inserted into self.N with value + 0. Auto converts the observation into discrete form. If option is not + None, then this uses self.M instead of self.N. Args: observation: observation tuple from the environment @@ -177,8 +175,9 @@ class MCTSLearner(ControllerBase): def _get_possible_options(self): """Returns a set of options that can be taken from the current node. - Goes through adjacency set of current node and finds which next nodes' - initiation condition is met. + + Goes through adjacency set of current node and finds which next + nodes' initiation condition is met. """ all_options = set(self.low_level_policies.keys()) @@ -211,9 +210,9 @@ class MCTSLearner(ControllerBase): return visited_aliases def _ucb_adjusted_q(self, observation, C=1): - """Computes Q_star(observation, option_i) plus the UCB term, which - is C*[predictor(observation, option_i)]/[1+N(observation, option_i)], - for all option_i in the adjacency set of the current node. + """Computes Q_star(observation, option_i) plus the UCB term, which is + C*[predictor(observation, option_i)]/[1+N(observation, option_i)], for + all option_i in the adjacency set of the current node. Args: observation: observation tuple from the environment @@ -257,14 +256,14 @@ class MCTSLearner(ControllerBase): return sum_rewards / (1 + self._get_visitation_count(observation)) def _select(self, observation, depth=0, visualize=False): - """MCTS selection function. For representation, we only use - the discrete part of the observation. + """MCTS selection function. For representation, we only use the + discrete part of the observation. Args: observation: observation tuple from the environment depth: current depth, starts from root node, hence 0 by default visualize: whether or not to visualize low level steps - + Returns the sum of values from the given observation. """ @@ -316,8 +315,8 @@ class MCTSLearner(ControllerBase): return self._value(observation), all_ep_R + episode_R def traverse(self, observation, visualize=False): - """Do a complete traversal from root to leaf. Assumes the - environment is reset and we are at the root node. + """Do a complete traversal from root to leaf. Assumes the environment + is reset and we are at the root node. Args: observation: observation from the environment @@ -329,13 +328,13 @@ class MCTSLearner(ControllerBase): return self._select(observation, visualize=visualize) def do_transition(self, observation, visualize=False): - """Do a transition using UCB metric, with the latest observation - from the episodic step. + """Do a transition using UCB metric, with the latest observation from + the episodic step. Args: observation: final observation from episodic step - visualize: whether or not to visualize low level steps - + visualize: whether or not to visualize low level steps + Returns o_star using UCB metric """ diff --git a/backends/online_mcts_controller.py b/backends/online_mcts_controller.py index 64343d8..a53d83b 100644 --- a/backends/online_mcts_controller.py +++ b/backends/online_mcts_controller.py @@ -5,7 +5,7 @@ import numpy as np class OnlineMCTSController(ControllerBase): - """Online MCTS""" + """Online MCTS.""" def __init__(self, env, low_level_policies, start_node_alias): """Constructor for manual policy execution. diff --git a/backends/policy_base.py b/backends/policy_base.py index ad1d3df..25bb5c3 100644 --- a/backends/policy_base.py +++ b/backends/policy_base.py @@ -1,3 +1,3 @@ class PolicyBase: - """Abstract policy base from which every policy backend is defined - and inherited.""" + """Abstract policy base from which every policy backend is defined and + inherited.""" diff --git a/env/env_base.py b/env/env_base.py index a68c066..4a8b9a5 100644 --- a/env/env_base.py +++ b/env/env_base.py @@ -3,23 +3,20 @@ from model_checker import Parser class GymCompliantEnvBase: def step(self, action): - """ Gym compliant step function which - will be implemented in the subclass. - """ + """Gym compliant step function which will be implemented in the + subclass.""" raise NotImplemented(self.__class__.__name__ + "step is not implemented.") def reset(self): - """ Gym compliant reset function which - will be implemented in the subclass. - """ + """Gym compliant reset function which will be implemented in the + subclass.""" raise NotImplemented(self.__class__.__name__ + "reset is not implemented.") def render(self): - """ Gym compliant step function which - will be implemented in the subclass. - """ + """Gym compliant step function which will be implemented in the + subclass.""" raise NotImplemented(self.__class__.__name__ + "render is not implemented.") @@ -49,7 +46,8 @@ class EpisodicEnvBase(GymCompliantEnvBase): def _init_LTL_preconditions(self): """Initialize the LTL preconditions (self._LTL_preconditions).. - in the subclass. + + in the subclass. """ return @@ -72,9 +70,8 @@ class EpisodicEnvBase(GymCompliantEnvBase): self._r_terminal = None def _terminal_reward_superposition(self, r_obs): - """Calculate the next value when observing "obs," from the prior - using min, max, or summation depending on the superposition_type. - """ + """Calculate the next value when observing "obs," from the prior using + min, max, or summation depending on the superposition_type.""" if r_obs is None: return @@ -133,7 +130,8 @@ class EpisodicEnvBase(GymCompliantEnvBase): return violate, info def current_model_checking_result(self): - """Returns whether or not any of the conditions is currently violated.""" + """Returns whether or not any of the conditions is currently + violated.""" for LTL_precondition in self._LTL_preconditions: if LTL_precondition.result == Parser.FALSE: @@ -155,9 +153,8 @@ class EpisodicEnvBase(GymCompliantEnvBase): @property def termination_condition(self): - """In the subclass, specify the condition for termination of the episode - (or the maneuver). - """ + """In the subclass, specify the condition for termination of the + episode (or the maneuver).""" if self._terminate_in_goal and self.goal_achieved: return True diff --git a/env/road_env.py b/env/road_env.py index 3a189ed..9743b58 100644 --- a/env/road_env.py +++ b/env/road_env.py @@ -1,6 +1,6 @@ class RoadEnv: - """The generic road env + """The generic road env. - TODO: Implement this generic road env for plugging-in other road envs. - TODO: roadEnv also having a step() function can cause a problem. + TODO: Implement this generic road env for plugging-in other road envs. + TODO: roadEnv also having a step() function can cause a problem. """ \ No newline at end of file diff --git a/env/simple_intersection/features.py b/env/simple_intersection/features.py index 9973cf6..2b70e51 100644 --- a/env/simple_intersection/features.py +++ b/env/simple_intersection/features.py @@ -165,9 +165,7 @@ class Features(object): return self.other_vehs[index] def get_features_tuple(self): - """ - continuous + discrete features - """ + """continuous + discrete features.""" feature = self.con_ego + self.dis_ego for other_veh in self.other_vehs: diff --git a/env/simple_intersection/shapes.py b/env/simple_intersection/shapes.py index 62f1566..d814290 100644 --- a/env/simple_intersection/shapes.py +++ b/env/simple_intersection/shapes.py @@ -14,7 +14,7 @@ class Rectangle(Shape): """Rectangle using OpenGL.""" def __init__(self, xmin, xmax, ymin, ymax, color=(0, 0, 0, 255)): - """ Constructor for Rectangle. + """Constructor for Rectangle. The rectangle has the four points (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin). diff --git a/env/simple_intersection/simple_intersection_env.py b/env/simple_intersection/simple_intersection_env.py index f37867f..1d61b91 100644 --- a/env/simple_intersection/simple_intersection_env.py +++ b/env/simple_intersection/simple_intersection_env.py @@ -150,7 +150,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase): return self.vehs[EGO_INDEX] def generate_scenario(self, **kwargs): - """Randomly generate a road scenario with + """Randomly generate a road scenario with. "the N-number of vehicles + an ego vehicle" @@ -609,9 +609,9 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase): @staticmethod def __regulate_veh_speed(veh, dist, v_max_multiplier): - """This private method set and regulate the speed of the - vehicle in a way that it never requires to travel the given - distance for complete stop. + """This private method set and regulate the speed of the vehicle in a + way that it never requires to travel the given distance for complete + stop. Args: veh: the vehicle reference. @@ -920,9 +920,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase): self.ego.theta) > self.max_ego_theta def _check_collisions(self): - """returns True when collision happens, False if not. - - """ + """returns True when collision happens, False if not.""" self.__ego_collision_happened = self.check_ego_collision() or \ self.is_ego_off_road() self.__others_collision_happened = self.check_other_veh_collisions() @@ -1004,8 +1002,8 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase): return True if (h_offroad and v_offroad) else False def cost(self, u): - """Calculate the driving cost of the ego, i.e., - the negative reward for the ego-driving. + """Calculate the driving cost of the ego, i.e., the negative reward for + the ego-driving. Args: u: the low-level input (a, dot_psi) to the ego vehicle. @@ -1045,8 +1043,9 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase): #TODO: Move this to utilities def normalize_tuple(self, vec, scale_factor=10): - """Normalizes each element in a tuple according to ranges defined in self.cost_normalization_ranges. - Normalizes between 0 and 1. And the scales by scale_factor + """Normalizes each element in a tuple according to ranges defined in + self.cost_normalization_ranges. Normalizes between 0 and 1. And the + scales by scale_factor. Args: vec: The tuple to be normalized @@ -1070,8 +1069,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase): return tuple(normalized_vec) def get_features_tuple(self): - """ - Get/calculate the features wrt. the current state variables + """Get/calculate the features wrt. the current state variables. Returns features tuple """ @@ -1173,11 +1171,12 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase): @property def goal_achieved(self): - """ - A property from the base class which is True if the goal - of the road scenario is achieved, otherwise False. This property is - used in both step of EpisodicEnvBase and the implementation of - the high-level reinforcement learning and execution. + """A property from the base class which is True if the goal of the road + scenario is achieved, otherwise False. + + This property is used in both step of EpisodicEnvBase and the + implementation of the high-level reinforcement learning and + execution. """ return (self.ego.x >= rd.hlanes.end_pos) and \ @@ -1220,9 +1219,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase): label.draw() def render(self): - """ - Gym compliant render function. - """ + """Gym compliant render function.""" if self.window is None: self.window = config_Pyglet(pyglet, NUM_HPIXELS, NUM_VPIXELS) diff --git a/env/simple_intersection/utilities.py b/env/simple_intersection/utilities.py index cffc663..9081f92 100644 --- a/env/simple_intersection/utilities.py +++ b/env/simple_intersection/utilities.py @@ -247,7 +247,8 @@ def draw_all_shapes(shapes): def calculate_v_max(dist): - """Calculate the maximum velocity you can reach at the given position ahead. + """Calculate the maximum velocity you can reach at the given position + ahead. Args: dist: the distance you travel from the current position. @@ -259,8 +260,8 @@ def calculate_v_max(dist): def calculate_s(v): - """Calculate the distance traveling when the vehicle is maximally - de-accelerating for a complete stop. + """Calculate the distance traveling when the vehicle is maximally de- + accelerating for a complete stop. Args: v: the current speed of the vehicle. diff --git a/high_level_policy_main.py b/high_level_policy_main.py index 6554c55..3c5a0f3 100644 --- a/high_level_policy_main.py +++ b/high_level_policy_main.py @@ -16,8 +16,8 @@ def high_level_policy_training(nb_steps=25000, visualize=False, tensorboard=False, save_path="highlevel_weights.h5f"): - """ - Do RL of the high-level policy and test it. + """Do RL of the high-level policy and test it. + Args: nb_steps: the number of steps to perform RL load_weights: True if the pre-learned NN weights are loaded (for initializations of NNs) diff --git a/low_level_policy_main.py b/low_level_policy_main.py index 209f645..c6dd335 100644 --- a/low_level_policy_main.py +++ b/low_level_policy_main.py @@ -16,8 +16,8 @@ def low_level_policy_training(maneuver, visualize=False, nb_episodes_for_test=10, tensorboard=False): - """ - Do RL of the low-level policy of the given maneuver and test it. + """Do RL of the low-level policy of the given maneuver and test it. + Args: maneuver: the name of the maneuver defined in config.json (e.g., 'default'). nb_steps: the number of steps to perform RL. diff --git a/mcts.py b/mcts.py index 7c4fa22..e68356b 100644 --- a/mcts.py +++ b/mcts.py @@ -34,8 +34,8 @@ def mcts_training(nb_traversals, visualize=False, load_saved=False, save_file="mcts.pickle"): - """ - Do RL of the low-level policy of the given maneuver and test it. + """Do RL of the low-level policy of the given maneuver and test it. + Args: nb_traversals: number of MCTS traversals save_every: save at every these many traversals @@ -91,8 +91,8 @@ def mcts_evaluation(nb_traversals, visualize=False, save_file="mcts.pickle", pretrained=False): - """ - Do RL of the low-level policy of the given maneuver and test it. + """Do RL of the low-level policy of the given maneuver and test it. + Args: nb_traversals: number of MCTS traversals save_every: save at every these many traversals @@ -170,8 +170,8 @@ def online_mcts(nb_episodes=10): first_time = False else: print('Stepping through ...') - features, R, terminal, info = options.controller.step_current_node( - visualize_low_level_steps=True) + features, R, terminal, info = options.controller.\ + step_current_node(visualize_low_level_steps=True) episode_reward += R print('Intermediate Reward: %f (ego x = %f)' % (R, options.env.vehs[0].x)) @@ -224,8 +224,8 @@ def evaluate_online_mcts(nb_episodes=20, nb_trials=5): first_time = False else: print('Stepping through ...') - features, R, terminal, info = options.controller.step_current_node( - visualize_low_level_steps=True) + features, R, terminal, info = options.controller.\ + step_current_node(visualize_low_level_steps=True) episode_reward += R print('Intermediate Reward: %f (ego x = %f)' % (R, options.env.vehs[0].x)) diff --git a/model_checker/LTL_property_base.py b/model_checker/LTL_property_base.py index 304a306..4d5a9d8 100644 --- a/model_checker/LTL_property_base.py +++ b/model_checker/LTL_property_base.py @@ -5,17 +5,17 @@ from model_checker.parser import Parser, Errors class LTLPropertyBase(object): """This is a base class that contains information of an LTL property. - It encapsulates the model-checking part (see check / check_incremental), - and contains additional information. The subclass needs to describe - specific APdict to be used. + It encapsulates the model-checking part (see check / + check_incremental), and contains additional information. The + subclass needs to describe specific APdict to be used. """ #: The atomic propositions dict you must set in the subclass. APdict = None def __init__(self, LTL_str, penalty, enabled=True): - """Constructor for LTLPropertyBase. - Assumes property does not change, but property may be applied to multiple traces. + """Constructor for LTLPropertyBase. Assumes property does not change, + but property may be applied to multiple traces. Args: LTL_str: the human readable string representation of the LTL property @@ -42,8 +42,10 @@ class LTLPropertyBase(object): self.result = Parser.UNDECIDED def reset_property(self): - """Resets existing property so that it can be applied to a new sequence of states. - Assumes init_property or check were previously called. + """Resets existing property so that it can be applied to a new sequence + of states. + + Assumes init_property or check were previously called. """ self.parser.ResetProperty() self.result = Parser.UNDECIDED @@ -64,11 +66,11 @@ class LTLPropertyBase(object): def check_incremental(self, state): """Checks an initialised property w.r.t. the next state in a trace. - Assumes init_property or check were previously called. - + Assumes init_property or check were previously called. + Args: state: next state (an integer) - + Returns: incremental result, in {TRUE, FALSE, UNDECIDED} """ diff --git a/model_checker/atomic_propositions_base.py b/model_checker/atomic_propositions_base.py index 7e50d70..4012cdf 100644 --- a/model_checker/atomic_propositions_base.py +++ b/model_checker/atomic_propositions_base.py @@ -2,9 +2,9 @@ class Bits(object): """A bit-control class that allows us bit-wise manipulation as shown in the example:: - bits = Bits() - bits[0] = False - bits[2] = bits[0] + bits = Bits() + bits[0] = False + bits[2] = bits[0] """ def __init__(self, value=0): @@ -35,8 +35,9 @@ class Bits(object): class AtomicPropositionsBase(Bits): """An AP-control base class for AP-wise manipulation. - the dictionary APdict and its length APdict_len has - to be given in the subclass + + the dictionary APdict and its length APdict_len has to be given in + the subclass """ APdict = None diff --git a/model_checker/parser.py b/model_checker/parser.py index 8a415fe..8a6dc99 100644 --- a/model_checker/parser.py +++ b/model_checker/parser.py @@ -191,9 +191,11 @@ class Parser(object): TRUE = 2 def Check_old(self, propscanner, trace): - """Deprecated method to check an entire trace with a new property Scanner. - Includes SetProperty, which includes ResetProperty. - """ + """Deprecated method to check an entire trace with a new property + Scanner. + + Includes SetProperty, which includes ResetProperty. + """ self.SetProperty(propscanner) for state in trace: result = self.CheckIncremental(state) @@ -201,9 +203,11 @@ class Parser(object): return result def Check(self, trace): - """Checks an entire trace w.r.t. an existing property Scanner. - Includes ResetProperty, but not SetProperty. - """ + """Checks an entire trace w.r.t. + + an existing property Scanner. Includes ResetProperty, but not + SetProperty. + """ self.ResetProperty() for state in trace: result = self.CheckIncremental(state) @@ -211,23 +215,23 @@ class Parser(object): return result def SetProperty(self, propscanner): - """Sets the property Scanner that tokenizes the property. - """ + """Sets the property Scanner that tokenizes the property.""" self.scanner = propscanner self.ResetProperty() def ResetProperty(self): - """Re-iniitializes an existing property Scanner. - """ + """Re-iniitializes an existing property Scanner.""" self.step = 0 self.maxfactor = -1 self.start = self.scanner.t = self.scanner.tokens def CheckIncremental(self, state): - """Checks a new state w.r.t. an existing property Scanner. - Constructs a trace from new states using a new or previous trace list. - If a previous trace list is used, states after index self.step are not valid. - """ + """Checks a new state w.r.t. + + an existing property Scanner. Constructs a trace from new states + using a new or previous trace list. If a previous trace list is + used, states after index self.step are not valid. + """ if len(self.trace) <= self.step: self.trace.append(state) else: diff --git a/model_checker/simple_intersection/classes.py b/model_checker/simple_intersection/classes.py index 471d7da..4a88c83 100644 --- a/model_checker/simple_intersection/classes.py +++ b/model_checker/simple_intersection/classes.py @@ -4,8 +4,7 @@ from .AP_dict import AP_dict_simple_intersection # TODO: classes.py is not a good name. Spliti the two classes into separate files and rename them. class AtomicPropositions(AtomicPropositionsBase): - """ An AP-control class for AP-wise manipulation as shown in the - example: + """An AP-control class for AP-wise manipulation as shown in the example: APs = AtomicPropositions() APs[0] = False # this is same as @@ -15,16 +14,15 @@ class AtomicPropositions(AtomicPropositionsBase): Requires: Index in [...] is an integer should be in the range {0,1,2, ..., AP_dict_len}. - """ APdict = AP_dict_simple_intersection class LTLProperty(LTLPropertyBase): - """ is a class that contains information of an LTL property in - simple_intersection road scenario. + """is a class that contains information of an LTL property in + simple_intersection road scenario. - It encapsulates the model-checking part and contains additional - information. + It encapsulates the model-checking part and contains additional + information. """ APdict = AP_dict_simple_intersection diff --git a/options/options_loader.py b/options/options_loader.py index 0befa6b..f58554f 100644 --- a/options/options_loader.py +++ b/options/options_loader.py @@ -5,8 +5,7 @@ from backends import RLController, DDPGLearner, MCTSLearner, OnlineMCTSControlle class OptionsGraph: - """ - Represent the options graph as a graph like structure. The configuration + """Represent the options graph as a graph like structure. The configuration is specified in a json file and consists of the following specific values: * nodes: dictionary of policy node aliases -> maneuver classes @@ -18,8 +17,7 @@ class OptionsGraph: visualize_low_level_steps = False def __init__(self, json_path, env_class, *env_args, **env_kwargs): - """ - Constructor for the options graph. + """Constructor for the options graph. Args: json_path: path to the json file which stores the configuration @@ -78,9 +76,9 @@ class OptionsGraph: "Controller to be used not specified") def step(self, option): - """Complete an episode using specified option. This assumes that - the manager's env is at a place where the option's initiation - condition is met. + """Complete an episode using specified option. This assumes that the + manager's env is at a place where the option's initiation condition is + met. Args: option: index of high level option to be executed @@ -111,8 +109,9 @@ class OptionsGraph: return self.env.reset() def set_controller_policy(self, policy): - """Sets the trained controller policy as a function which takes in feature vector and returns - an option index(int). By default, trained_policy is None + """Sets the trained controller policy as a function which takes in + feature vector and returns an option index(int). By default, + trained_policy is None. Args: policy: a function which takes in feature vector and returns an option index(int) @@ -120,7 +119,7 @@ class OptionsGraph: self.controller.set_trained_policy(policy) def set_controller_args(self, **kwargs): - """Sets custom arguments depending on the chosen controller + """Sets custom arguments depending on the chosen controller. Args: **kwargs: dictionary of args and values @@ -128,9 +127,11 @@ class OptionsGraph: self.controller.set_controller_args(**kwargs) def execute_controller_policy(self): - """Performs low level steps and transitions to other nodes using controller transition policy. + """Performs low level steps and transitions to other nodes using + controller transition policy. - Returns state after one step, step reward, episode_termination_flag, info + Returns state after one step, step reward, + episode_termination_flag, info """ if self.controller.can_transition(): self.controller.do_transition() @@ -138,7 +139,8 @@ class OptionsGraph: return self.controller.low_level_step_current_node() def set_current_node(self, node_alias): - """Sets the current node for controller. Used for training/testing a particular node + """Sets the current node for controller. Used for training/testing a + particular node. Args: node_alias: alias of the node as per config file diff --git a/options/simple_intersection/maneuver_base.py b/options/simple_intersection/maneuver_base.py index 46de914..78f7a63 100644 --- a/options/simple_intersection/maneuver_base.py +++ b/options/simple_intersection/maneuver_base.py @@ -172,8 +172,9 @@ class ManeuverBase(EpisodicEnvBase): return self._features_dim_reduction(features), reward, terminal, info def get_reduced_feature_length(self): - """ - get the length of the feature tuple after applying _features_dim_reduction. + """get the length of the feature tuple after applying + _features_dim_reduction. + :return: """ return len(self.get_reduced_features_tuple()) @@ -235,8 +236,10 @@ class ManeuverBase(EpisodicEnvBase): self.env.render() # Do nothing but calling self.env.render() def set_low_level_trained_policy(self, trained_policy): - """Sets the trained policy as a function which takes in feature vector and returns - an action (a, dot_psi). By default, trained_policy is None + """Sets the trained policy as a function which takes in feature vector + and returns an action (a, dot_psi). + + By default, trained_policy is None """ self.trained_policy = trained_policy @@ -254,8 +257,7 @@ class ManeuverBase(EpisodicEnvBase): @staticmethod def _features_dim_reduction(features_tuple): - """ - Reduce the dimension of the features in step and reset. + """Reduce the dimension of the features in step and reset. Param: features_tuple: a tuple obtained by e.g., self.env.get_features_tuple() Return: the reduced features tuple (by default, it returns features_tuple itself. @@ -264,9 +266,11 @@ class ManeuverBase(EpisodicEnvBase): # TODO: Determine whether this method depends on the external features_tuple or for simplicity, define and use a features_tuple within the class. def low_level_policy(self, reduced_features_tuple): - """the low level policy as a map from a feature vector to an action - (a, dot_psi). By default, it'll call low_level_manual_policy below - if it's implemented in the subclass. + """the low level policy as a map from a feature vector to an action (a, + dot_psi). + + By default, it'll call low_level_manual_policy below if it's + implemented in the subclass. """ if self.trained_policy is None: return self._low_level_manual_policy() @@ -274,9 +278,12 @@ class ManeuverBase(EpisodicEnvBase): return self.trained_policy(reduced_features_tuple) def _low_level_manual_policy(self): - """the manually-defined low level policy as a map from a feature vector to an action - (a, dot_psi). _low_level_policy will call this manual policy unless modified in the subclass. - Implement this in the subclass whenever necessary. + """the manually-defined low level policy as a map from a feature vector + to an action (a, dot_psi). + + _low_level_policy will call this manual policy unless modified + in the subclass. Implement this in the subclass whenever + necessary. """ raise NotImplemented(self.__class__.__name__ + "._low_Level_manual_policy is not implemented.") @@ -295,9 +302,9 @@ class ManeuverBase(EpisodicEnvBase): enable_LTL_preconditions=True, timeout=np.infty, **kwargs): - """generates the scenario for low-level policy learning and validation. This method - will be used in generate_learning_scenario and generate_validation_scenario in - the subclasses. + """generates the scenario for low-level policy learning and validation. + This method will be used in generate_learning_scenario and + generate_validation_scenario in the subclasses. Param: enable_LTL_preconditions: whether to enable LTL preconditions in the maneuver or not @@ -356,7 +363,6 @@ class ManeuverBase(EpisodicEnvBase): return # do nothing unless specified in the subclass def _update_param(self): - """Update the parameters in the gym-compliant 'step' method above - """ + """Update the parameters in the gym-compliant 'step' method above.""" return # do nothing unless specified in the subclass diff --git a/ppo2_training.py b/ppo2_training.py index 5c5609c..269b7fd 100644 --- a/ppo2_training.py +++ b/ppo2_training.py @@ -29,7 +29,8 @@ if __name__ == "__main__": "right_weights.h5f" ) # Save the NN weights for reloading them in the future. - # Uncomment this after training to use trained model. Comment this to use manually defined policy + # Uncomment this after training to use trained model. + # Comment this to use manually defined policy #agent.load_weights("right_weights.h5f") print("Testing model...") -- GitLab