Commit f428cd41 authored by Ashish Gaurav

run docformatter

parent 72e44d55
@@ -48,10 +48,9 @@ class PPO2Agent(LearnerBase):
         return MlpPolicy

     def create_agent(self, policy, tensorboard):
-        """Creates a PPO agent
+        """Creates a PPO agent.

-        Returns:
-            stable_baselines PPO2 object
+        Returns: stable_baselines PPO2 object
         """
         if tensorboard:
             return PPO2(
...
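As a rough illustration of the factory pattern in the hunk above, here is a minimal sketch of how create_agent presumably wires a policy into stable_baselines' PPO2, assuming the standard PPO2 constructor and its tensorboard_log keyword; the env argument and log path are placeholders, not taken from this diff.

from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy

def create_agent(policy, env, tensorboard):
    # With tensorboard enabled, PPO2 writes event files that
    # `tensorboard --logdir ./logs` can render; otherwise train plainly.
    if tensorboard:
        return PPO2(policy, env, verbose=1, tensorboard_log="./logs")
    return PPO2(policy, env, verbose=1)

# agent = create_agent(MlpPolicy, env, tensorboard=True)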
@@ -18,16 +18,17 @@ class ControllerBase(PolicyBase):
             setattr(self, prop, kwargs.get(prop, default))

     def can_transition(self):
-        """Returns boolean signifying whether we can transition. To be
-        implemented in subclass.
+        """Returns boolean signifying whether we can transition.
+
+        To be implemented in subclass.
         """
         raise NotImplemented(self.__class__.__name__ + \
             "can_transition is not implemented.")

     def do_transition(self, observation):
-        """Do a transition, assuming we can transition. To be
-        implemented in subclass.
+        """Do a transition, assuming we can transition. To be implemented in
+        subclass.

         Args:
             observation: final observation from episodic step
@@ -37,7 +38,7 @@ class ControllerBase(PolicyBase):
             "do_transition is not implemented.")

     def set_current_node(self, node_alias):
-        """Sets the current node which is being executed
+        """Sets the current node which is being executed.

         Args:
             node: node alias of the node to be set
...
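The contract these docstrings spell out is that every controller subclass supplies both methods. A minimal hypothetical subclass might look like this; the node name and the trivial conditions are invented for illustration only.

class AlwaysTransition(ControllerBase):
    """Toy controller, illustrative only; not part of this commit."""

    def can_transition(self):
        # A real subclass would test the current node's termination
        # condition and the next node's initiation condition.
        return True

    def do_transition(self, observation):
        # `observation` is the final observation from the episodic step;
        # set_current_node is assumed usable here (its body is outside
        # this diff).
        self.set_current_node("default")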
@@ -431,8 +431,8 @@ class DQNAgentOverOptions(DQNAgent):
             self.low_level_policies.keys())

     def __get_invalid_node_indices(self):
-        """Returns a list of option indices that are invalid according to initiation conditions.
-        """
+        """Returns a list of option indices that are invalid according to
+        initiation conditions."""
         invalid_node_indices = list()
         for index, option_alias in enumerate(self.low_level_policy_aliases):
             self.low_level_policies[option_alias].reset_maneuver()
...
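Reading between the lines of __get_invalid_node_indices: each option's maneuver is reset and its initiation condition is then checked. A sketch of that loop, assuming an initiation_condition-style predicate on each low-level policy (the predicate name is a guess, as the hunk cuts off before the check):

def get_invalid_node_indices(low_level_policies, aliases, observation):
    invalid_node_indices = list()
    for index, option_alias in enumerate(aliases):
        policy = low_level_policies[option_alias]
        policy.reset_maneuver()
        # An option is invalid when its initiation condition does not hold.
        if not policy.initiation_condition(observation):
            invalid_node_indices.append(index)
    return invalid_node_indices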
@@ -20,8 +20,8 @@ class ManualPolicy(ControllerBase):
         self.adj = transition_adj

     def _transition(self):
-        """Check if the current node's termination condition is met and if
-        it is possible to transition to another node, i.e. its initiation
-        condition is met. This is an internal function.
+        """Check if the current node's termination condition is met and if it
+        is possible to transition to another node, i.e. its initiation
+        condition is met. This is an internal function.

         Returns the new node if a transition can happen, None otherwise
...
@@ -4,9 +4,8 @@ import pickle

 class MCTSLearner(ControllerBase):
-    """Monte Carlo Tree Search implementation using the UCB1 and
-    progressive widening approach as explained in Paxton et al (2017).
-    """
+    """Monte Carlo Tree Search implementation using the UCB1 and progressive
+    widening approach as explained in Paxton et al (2017)."""

     _ucb_vals = set()
@@ -76,8 +75,8 @@ class MCTSLearner(ControllerBase):
             self.new_node_num = to_restore['new_node_num']

     def _create_node(self, low_level_policy):
-        """Create the node associated with curr_node_num, using the
-        given low level policy.
+        """Create the node associated with curr_node_num, using the given low
+        level policy.

         Args:
             low_level_policy: the option's alias
@@ -92,11 +91,10 @@ class MCTSLearner(ControllerBase):
         return created_node_num, {"policy": low_level_policy}

     def _to_discrete(self, observation):
-        """Converts observation to a discrete observation tuple. Also
-        append (a) whether we are following a vehicle, and (b) whether
-        there is a vehicle in the opposite lane in the approximately
-        the same x position. These values will be useful for Follow
-        and ChangeLane maneuvers.
+        """Converts observation to a discrete observation tuple. Also append
+        (a) whether we are following a vehicle, and (b) whether there is a
+        vehicle in the opposite lane in the approximately the same x position.
+        These values will be useful for Follow and ChangeLane maneuvers.

         Args:
             observation: observation tuple from the environment
@@ -137,9 +135,9 @@ class MCTSLearner(ControllerBase):

     def _get_visitation_count(self, observation, option=None):
         """Finds the visitation count of the discrete form of the observation.
-        If discrete observation not found, then inserted into self.N with
-        value 0. Auto converts the observation into discrete form. If option
-        is not None, then this uses self.M instead of self.N
+        If discrete observation not found, then inserted into self.N with value
+        0. Auto converts the observation into discrete form. If option is not
+        None, then this uses self.M instead of self.N.

         Args:
             observation: observation tuple from the environment
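The N/M bookkeeping this docstring describes (insert-with-zero on first sight, N keyed by the discrete observation, M keyed by observation-option pairs) is the usual MCTS visit-count table. A sketch with plain dicts, where _to_discrete stands in for the class's own discretizer:

def get_visitation_count(self, observation, option=None):
    obs = self._to_discrete(observation)
    if option is None:
        # State visitation count, defaulting to 0 when first seen.
        self.N.setdefault(obs, 0)
        return self.N[obs]
    # State-option visitation count, same default-to-zero behavior.
    self.M.setdefault((obs, option), 0)
    return self.M[(obs, option)]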
@@ -177,8 +175,9 @@ class MCTSLearner(ControllerBase):

     def _get_possible_options(self):
         """Returns a set of options that can be taken from the current node.
-        Goes through adjacency set of current node and finds which next nodes'
-        initiation condition is met.
+
+        Goes through adjacency set of current node and finds which next
+        nodes' initiation condition is met.
         """

         all_options = set(self.low_level_policies.keys())
@@ -211,9 +210,9 @@ class MCTSLearner(ControllerBase):
         return visited_aliases

     def _ucb_adjusted_q(self, observation, C=1):
-        """Computes Q_star(observation, option_i) plus the UCB term, which
-        is C*[predictor(observation, option_i)]/[1+N(observation, option_i)],
-        for all option_i in the adjacency set of the current node.
+        """Computes Q_star(observation, option_i) plus the UCB term, which is
+        C*[predictor(observation, option_i)]/[1+N(observation, option_i)], for
+        all option_i in the adjacency set of the current node.

         Args:
             observation: observation tuple from the environment
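In code, the quantity in this docstring is Q_star(s, o) + C * predictor(s, o) / (1 + N(s, o)) for each option o adjacent to the current node. A self-contained sketch; q_star, predictor and visit_count here are stand-in dicts/callables, not the class's actual members:

def ucb_adjusted_q(q_star, predictor, visit_count, options, C=1):
    # Exploitation term plus an exploration bonus that decays as the
    # state-option pair accumulates visits.
    return {
        o: q_star[o] + C * predictor(o) / (1.0 + visit_count(o))
        for o in options
    }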
@@ -257,14 +256,14 @@ class MCTSLearner(ControllerBase):
         return sum_rewards / (1 + self._get_visitation_count(observation))

     def _select(self, observation, depth=0, visualize=False):
-        """MCTS selection function. For representation, we only use
-        the discrete part of the observation.
+        """MCTS selection function. For representation, we only use the
+        discrete part of the observation.

         Args:
             observation: observation tuple from the environment
             depth: current depth, starts from root node, hence 0 by default
             visualize: whether or not to visualize low level steps

         Returns the sum of values from the given observation.
         """
@@ -316,8 +315,8 @@ class MCTSLearner(ControllerBase):
         return self._value(observation), all_ep_R + episode_R

     def traverse(self, observation, visualize=False):
-        """Do a complete traversal from root to leaf. Assumes the
-        environment is reset and we are at the root node.
+        """Do a complete traversal from root to leaf. Assumes the environment
+        is reset and we are at the root node.

         Args:
             observation: observation from the environment
@@ -329,13 +328,13 @@ class MCTSLearner(ControllerBase):
         return self._select(observation, visualize=visualize)

     def do_transition(self, observation, visualize=False):
-        """Do a transition using UCB metric, with the latest observation
-        from the episodic step.
+        """Do a transition using UCB metric, with the latest observation from
+        the episodic step.

         Args:
             observation: final observation from episodic step
             visualize: whether or not to visualize low level steps

         Returns o_star using UCB metric
         """
...
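Taken together, traverse and do_transition suggest a usage pattern along these lines; a hedged sketch, where mcts, env and the reset protocol are assumed from context rather than shown in this diff:

env.reset()
observation = env.get_features_tuple()

# One root-to-leaf MCTS pass; assumes the environment was just reset.
value, total_reward = mcts.traverse(observation)

# During execution, pick the next option o_star via the UCB metric.
o_star = mcts.do_transition(observation)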
@@ -5,7 +5,7 @@ import numpy as np

 class OnlineMCTSController(ControllerBase):
-    """Online MCTS"""
+    """Online MCTS."""

     def __init__(self, env, low_level_policies, start_node_alias):
         """Constructor for manual policy execution.
...
 class PolicyBase:
-    """Abstract policy base from which every policy backend is defined
-    and inherited."""
+    """Abstract policy base from which every policy backend is defined and
+    inherited."""
@@ -3,23 +3,20 @@ from model_checker import Parser

 class GymCompliantEnvBase:
     def step(self, action):
-        """ Gym compliant step function which
-        will be implemented in the subclass.
-        """
+        """Gym compliant step function which will be implemented in the
+        subclass."""
         raise NotImplemented(self.__class__.__name__ +
                              "step is not implemented.")

     def reset(self):
-        """ Gym compliant reset function which
-        will be implemented in the subclass.
-        """
+        """Gym compliant reset function which will be implemented in the
+        subclass."""
         raise NotImplemented(self.__class__.__name__ +
                              "reset is not implemented.")

     def render(self):
-        """ Gym compliant step function which
-        will be implemented in the subclass.
-        """
+        """Gym compliant step function which will be implemented in the
+        subclass."""
         raise NotImplemented(self.__class__.__name__ +
                              "render is not implemented.")
@@ -49,7 +46,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):

     def _init_LTL_preconditions(self):
         """Initialize the LTL preconditions (self._LTL_preconditions)..
+
         in the subclass.
         """
         return
@@ -72,9 +70,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
             self._r_terminal = None

     def _terminal_reward_superposition(self, r_obs):
-        """Calculate the next value when observing "obs," from the prior
-        using min, max, or summation depending on the superposition_type.
-        """
+        """Calculate the next value when observing "obs," from the prior using
+        min, max, or summation depending on the superposition_type."""
         if r_obs is None:
             return
@@ -133,7 +130,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
         return violate, info

     def current_model_checking_result(self):
-        """Returns whether or not any of the conditions is currently violated."""
+        """Returns whether or not any of the conditions is currently
+        violated."""
         for LTL_precondition in self._LTL_preconditions:
             if LTL_precondition.result == Parser.FALSE:
@@ -155,9 +153,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):

     @property
     def termination_condition(self):
-        """In the subclass, specify the condition for termination of the episode
-        (or the maneuver).
-        """
+        """In the subclass, specify the condition for termination of the
+        episode (or the maneuver)."""
         if self._terminate_in_goal and self.goal_achieved:
             return True
...
 class RoadEnv:
-    """The generic road env
+    """The generic road env.

     TODO: Implement this generic road env for plugging-in other road envs.
     TODO: roadEnv also having a step() function can cause a problem.
     """
\ No newline at end of file
@@ -165,9 +165,7 @@ class Features(object):
         return self.other_vehs[index]

     def get_features_tuple(self):
-        """
-        continuous + discrete features
-        """
+        """continuous + discrete features."""
         feature = self.con_ego + self.dis_ego

         for other_veh in self.other_vehs:
...
@@ -14,7 +14,7 @@ class Rectangle(Shape):
     """Rectangle using OpenGL."""

     def __init__(self, xmin, xmax, ymin, ymax, color=(0, 0, 0, 255)):
-        """ Constructor for Rectangle.
+        """Constructor for Rectangle.

         The rectangle has the four points (xmin, ymin), (xmin, ymax),
         (xmax, ymax), (xmax, ymin).
...
@@ -150,7 +150,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
         return self.vehs[EGO_INDEX]

     def generate_scenario(self, **kwargs):
-        """Randomly generate a road scenario with
+        """Randomly generate a road scenario with.

         "the N-number of vehicles + an ego vehicle"
@@ -609,9 +609,9 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):

     @staticmethod
     def __regulate_veh_speed(veh, dist, v_max_multiplier):
-        """This private method set and regulate the speed of the
-        vehicle in a way that it never requires to travel the given
-        distance for complete stop.
+        """This private method set and regulate the speed of the vehicle in a
+        way that it never requires to travel the given distance for complete
+        stop.

         Args:
             veh: the vehicle reference.
@@ -920,9 +920,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
                        self.ego.theta) > self.max_ego_theta

     def _check_collisions(self):
-        """returns True when collision happens, False if not.
-        """
-
+        """returns True when collision happens, False if not."""
         self.__ego_collision_happened = self.check_ego_collision() or \
             self.is_ego_off_road()
         self.__others_collision_happened = self.check_other_veh_collisions()
...
@@ -1004,8 +1002,8 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
         return True if (h_offroad and v_offroad) else False

     def cost(self, u):
-        """Calculate the driving cost of the ego, i.e.,
-        the negative reward for the ego-driving.
+        """Calculate the driving cost of the ego, i.e., the negative reward for
+        the ego-driving.

         Args:
             u: the low-level input (a, dot_psi) to the ego vehicle.
@@ -1045,8 +1043,9 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):

     #TODO: Move this to utilities
     def normalize_tuple(self, vec, scale_factor=10):
-        """Normalizes each element in a tuple according to ranges defined in self.cost_normalization_ranges.
-        Normalizes between 0 and 1. And the scales by scale_factor
+        """Normalizes each element in a tuple according to ranges defined in
+        self.cost_normalization_ranges. Normalizes between 0 and 1. And the
+        scales by scale_factor.

         Args:
             vec: The tuple to be normalized
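What this docstring describes is per-element min-max normalization followed by scaling. A sketch, with ranges standing in for self.cost_normalization_ranges, assumed to be (low, high) pairs:

def normalize_tuple(vec, ranges, scale_factor=10):
    normalized_vec = []
    for x, (low, high) in zip(vec, ranges):
        # Map the element into [0, 1] against its range, then scale.
        normalized_vec.append(scale_factor * (x - low) / (high - low))
    return tuple(normalized_vec)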
@@ -1070,8 +1069,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
         return tuple(normalized_vec)

     def get_features_tuple(self):
-        """
-        Get/calculate the features wrt. the current state variables
+        """Get/calculate the features wrt. the current state variables.

         Returns features tuple
         """
@@ -1173,11 +1171,12 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):

     @property
     def goal_achieved(self):
-        """
-        A property from the base class which is True if the goal
-        of the road scenario is achieved, otherwise False. This property is
-        used in both step of EpisodicEnvBase and the implementation of
-        the high-level reinforcement learning and execution.
+        """A property from the base class which is True if the goal of the road
+        scenario is achieved, otherwise False.
+
+        This property is used in both step of EpisodicEnvBase and the
+        implementation of the high-level reinforcement learning and
+        execution.
         """
         return (self.ego.x >= rd.hlanes.end_pos) and \
@@ -1220,9 +1219,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
             label.draw()

     def render(self):
-        """
-        Gym compliant render function.
-        """
+        """Gym compliant render function."""
         if self.window is None:
             self.window = config_Pyglet(pyglet, NUM_HPIXELS, NUM_VPIXELS)
...
@@ -247,7 +247,8 @@ def draw_all_shapes(shapes):

 def calculate_v_max(dist):
-    """Calculate the maximum velocity you can reach at the given position ahead.
+    """Calculate the maximum velocity you can reach at the given position
+    ahead.

     Args:
         dist: the distance you travel from the current position.
@@ -259,8 +260,8 @@ def calculate_v_max(dist):

 def calculate_s(v):
-    """Calculate the distance traveling when the vehicle is maximally
-    de-accelerating for a complete stop.
+    """Calculate the distance traveling when the vehicle is maximally de-
+    accelerating for a complete stop.

     Args:
         v: the current speed of the vehicle.
...
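Under a constant maximal deceleration (a value this diff does not show), these two helpers appear to be the standard stopping-distance kinematics v_max(d) = sqrt(2*a*d) and s(v) = v^2 / (2*a). A sketch with an assumed constant:

import math

MAX_DECELERATION = 5.0  # m/s^2; assumed for illustration, not from the diff

def calculate_v_max(dist):
    # Highest speed from which a complete stop still fits within `dist`.
    return math.sqrt(2.0 * MAX_DECELERATION * dist)

def calculate_s(v):
    # Distance covered while maximally decelerating to a complete stop.
    return v * v / (2.0 * MAX_DECELERATION)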
@@ -16,8 +16,8 @@ def high_level_policy_training(nb_steps=25000,
                                visualize=False,
                                tensorboard=False,
                                save_path="highlevel_weights.h5f"):
-    """
-    Do RL of the high-level policy and test it.
+    """Do RL of the high-level policy and test it.

     Args:
         nb_steps: the number of steps to perform RL
         load_weights: True if the pre-learned NN weights are loaded (for initializations of NNs)
...
@@ -16,8 +16,8 @@ def low_level_policy_training(maneuver,
                               visualize=False,
                               nb_episodes_for_test=10,
                               tensorboard=False):
-    """
-    Do RL of the low-level policy of the given maneuver and test it.
+    """Do RL of the low-level policy of the given maneuver and test it.

     Args:
         maneuver: the name of the maneuver defined in config.json (e.g., 'default').
         nb_steps: the number of steps to perform RL.
...
@@ -34,8 +34,8 @@ def mcts_training(nb_traversals,
                   visualize=False,
                   load_saved=False,
                   save_file="mcts.pickle"):
-    """
-    Do RL of the low-level policy of the given maneuver and test it.
+    """Do RL of the low-level policy of the given maneuver and test it.

     Args:
         nb_traversals: number of MCTS traversals
         save_every: save at every these many traversals
@@ -91,8 +91,8 @@ def mcts_evaluation(nb_traversals,
                     visualize=False,
                     save_file="mcts.pickle",
                     pretrained=False):
-    """
-    Do RL of the low-level policy of the given maneuver and test it.
+    """Do RL of the low-level policy of the given maneuver and test it.

     Args:
         nb_traversals: number of MCTS traversals
         save_every: save at every these many traversals
@@ -170,8 +170,8 @@ def online_mcts(nb_episodes=10):
                 first_time = False
             else:
                 print('Stepping through ...')
-                features, R, terminal, info = options.controller.step_current_node(
-                    visualize_low_level_steps=True)
+                features, R, terminal, info = options.controller.\
+                    step_current_node(visualize_low_level_steps=True)
                 episode_reward += R
                 print('Intermediate Reward: %f (ego x = %f)' %
                       (R, options.env.vehs[0].x))
@@ -224,8 +224,8 @@ def evaluate_online_mcts(nb_episodes=20, nb_trials=5):
                 first_time = False
             else:
                 print('Stepping through ...')
-                features, R, terminal, info = options.controller.step_current_node(
-                    visualize_low_level_steps=True)
+                features, R, terminal, info = options.controller.\
+                    step_current_node(visualize_low_level_steps=True)
                 episode_reward += R
                 print('Intermediate Reward: %f (ego x = %f)' %
                       (R, options.env.vehs[0].x))
...
@@ -5,17 +5,17 @@ from model_checker.parser import Parser, Errors

 class LTLPropertyBase(object):
     """This is a base class that contains information of an LTL property.

-    It encapsulates the model-checking part (see check / check_incremental),
-    and contains additional information. The subclass needs to describe
-    specific APdict to be used.
+    It encapsulates the model-checking part (see check /
+    check_incremental), and contains additional information. The
+    subclass needs to describe specific APdict to be used.
     """

     #: The atomic propositions dict you must set in the subclass.
     APdict = None

     def __init__(self, LTL_str, penalty, enabled=True):
-        """Constructor for LTLPropertyBase.
-        Assumes property does not change, but property may be applied to multiple traces.
+        """Constructor for LTLPropertyBase. Assumes property does not change,
+        but property may be applied to multiple traces.

         Args:
             LTL_str: the human readable string representation of the LTL property
@@ -42,8 +42,10 @@ class LTLPropertyBase(object):
         self.result = Parser.UNDECIDED

     def reset_property(self):
-        """Resets existing property so that it can be applied to a new sequence of states.
-        Assumes init_property or check were previously called.
+        """Resets existing property so that it can be applied to a new sequence
+        of states.
+
+        Assumes init_property or check were previously called.
         """
         self.parser.ResetProperty()
         self.result = Parser.UNDECIDED
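A plausible end-to-end use of the API described above, given the check / check_incremental / reset_property cycle; the subclass, property string, APdict and trace are invented for illustration, and the penalty attribute is assumed to be stored by the constructor:

class MyLTLProperty(LTLPropertyBase):
    APdict = {'collision': 0}  # hypothetical atomic-proposition dict

prop = MyLTLProperty("G !collision", penalty=200)
for state in trace:           # `trace` is an assumed sequence of AP states
    prop.check_incremental(state)
if prop.result == Parser.FALSE:
    print("property violated; penalty", prop.penalty)
prop.reset_property()         # ready to be applied to a new trace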
@@ -64,11 +66,11 @@ class LTLPropertyBase(object):
     def check_incremental(self, state):