Commit f2171d2c authored by Aravind Balakrishnan

Merge branch 'formatting' into 'master'

Formatting

See merge request !3
parents e1fdb162 7abc600f
*.pyc
*.py~
__pycache__
......@@ -35,7 +35,9 @@ class PPO2Agent(LearnerBase):
self.log_path = log_path
self.env = DummyVecEnv([lambda: env]) #PPO2 requires a vectorized environment for parallel training
self.env = DummyVecEnv([
lambda: env
]) #PPO2 requires a vectorized environment for parallel training
self.agent_model = self.create_agent(policy, tensorboard)
def get_default_policy(self):
......@@ -46,13 +48,13 @@ class PPO2Agent(LearnerBase):
return MlpPolicy
def create_agent(self, policy, tensorboard):
"""Creates a PPO agent
"""Creates a PPO agent.
Returns:
stable_baselines PPO2 object
Returns: stable_baselines PPO2 object
"""
if tensorboard:
return PPO2(policy, self.env, verbose=1, tensorboard_log=self.log_path)
return PPO2(
policy, self.env, verbose=1, tensorboard_log=self.log_path)
else:
return PPO2(policy, self.env, verbose=1)
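For context on the DummyVecEnv and create_agent changes above: stable-baselines' PPO2 only accepts vectorized environments, so even a single env is wrapped in a one-element DummyVecEnv, and tensorboard_log is passed only when TensorBoard logging is requested. A minimal sketch of the same pattern outside this class (assumes stable-baselines 2.x; the CartPole env and log path are placeholders, not taken from this repo):

import gym
from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv

# PPO2 requires a vectorized env, even for a single environment instance.
env = DummyVecEnv([lambda: gym.make("CartPole-v1")])

# tensorboard_log is optional; omit it to skip writing TensorBoard events.
model = PPO2(MlpPolicy, env, verbose=1, tensorboard_log="./logs")
model.learn(total_timesteps=10000)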
......@@ -100,7 +102,8 @@ class PPO2Agent(LearnerBase):
episode_rewards[-1] += rewards[0]
if dones[0] or current_step > nb_max_episode_steps:
obs = self.env.reset()
print ("Episode ", current_episode, "reward: ", episode_rewards[-1])
print("Episode ", current_episode, "reward: ",
episode_rewards[-1])
episode_rewards.append(0.0)
current_episode += 1
current_step = 0
......
......@@ -18,16 +18,17 @@ class ControllerBase(PolicyBase):
setattr(self, prop, kwargs.get(prop, default))
def can_transition(self):
"""Returns boolean signifying whether we can transition. To be
implemented in subclass.
"""Returns boolean signifying whether we can transition.
To be implemented in subclass.
"""
raise NotImplementedError(self.__class__.__name__ + \
" can_transition is not implemented.")
def do_transition(self, observation):
"""Do a transition, assuming we can transition. To be
implemented in subclass.
"""Do a transition, assuming we can transition. To be implemented in
subclass.
Args:
observation: final observation from episodic step
......@@ -37,7 +38,7 @@ class ControllerBase(PolicyBase):
"do_transition is not implemented.")
def set_current_node(self, node_alias):
"""Sets the current node which is being executed
"""Sets the current node which is being executed.
Args:
node: node alias of the node to be set
......@@ -50,11 +51,13 @@ class ControllerBase(PolicyBase):
Returns state at end of node execution, total reward, episode_termination_flag, info
'''
def step_current_node(self, visualize_low_level_steps=False):
total_reward = 0
self.node_terminal_state_reached = False
while not self.node_terminal_state_reached:
observation, reward, terminal, info = self.low_level_step_current_node()
observation, reward, terminal, info = self.low_level_step_current_node(
)
if visualize_low_level_steps:
self.env.render()
total_reward += reward
......@@ -70,9 +73,12 @@ class ControllerBase(PolicyBase):
Returns state after one step, step reward, episode_termination_flag, info
'''
def low_level_step_current_node(self):
u_ego = self.current_node.low_level_policy(self.current_node.get_reduced_features_tuple())
u_ego = self.current_node.low_level_policy(
self.current_node.get_reduced_features_tuple())
feature, R, terminal, info = self.current_node.step(u_ego)
self.node_terminal_state_reached = terminal
return self.env.get_features_tuple(), R, self.env.termination_condition, info
return self.env.get_features_tuple(
), R, self.env.termination_condition, info
......@@ -47,7 +47,8 @@ class DDPGLearner(LearnerBase):
"oup_mu": 0, # OrnsteinUhlenbeckProcess mu
"oup_sigma": 1, # OrnsteinUhlenbeckProcess sigma
"oup_sigma_min": 0.5, # OrnsteinUhlenbeckProcess sigma min
"oup_annealing_steps": 500000, # OrnsteinUhlenbeckProcess n-step annealing
"oup_annealing_steps":
500000, # OrnsteinUhlenbeckProcess n-step annealing
"nb_steps_warmup_critic": 100, # steps for critic to warmup
"nb_steps_warmup_actor": 100, # steps for actor to warmup
"target_model_update": 1e-3 # target model update frequency
......@@ -160,7 +161,12 @@ class DDPGLearner(LearnerBase):
target_model_update=1e-3)
# TODO: give params like lr_actor and lr_critic to set different lr of Actor and Critic.
agent.compile([Adam(lr=self.lr*1e-2, clipnorm=1.), Adam(lr=self.lr, clipnorm=1.)], metrics=['mae'])
agent.compile(
[
Adam(lr=self.lr * 1e-2, clipnorm=1.),
Adam(lr=self.lr, clipnorm=1.)
],
metrics=['mae'])
return agent
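On the two-optimizer compile call above: keras-rl's DDPGAgent.compile accepts a two-element list interpreted as [actor_optimizer, critic_optimizer], so this setup trains the actor with a learning rate 100x smaller than the critic's. A small sketch of just that piece (the base lr is a placeholder; agent construction elided):

from keras.optimizers import Adam

lr = 1e-3  # placeholder base learning rate
# First optimizer updates the actor, second updates the critic (keras-rl order).
optimizers = [Adam(lr=lr * 1e-2, clipnorm=1.), Adam(lr=lr, clipnorm=1.)]
# agent.compile(optimizers, metrics=['mae'])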
......@@ -177,7 +183,11 @@ class DDPGLearner(LearnerBase):
callbacks = []
if model_checkpoints:
callbacks += [ModelIntervalCheckpoint('./checkpoints/checkpoint_weights.h5f', interval=checkpoint_interval)]
callbacks += [
ModelIntervalCheckpoint(
'./checkpoints/checkpoint_weights.h5f',
interval=checkpoint_interval)
]
if tensorboard:
callbacks += [TensorBoard(log_dir='./logs')]
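The callback list assembled above only takes effect once it is handed to the agent's fit call. A hedged usage sketch (keras-rl ModelIntervalCheckpoint plus Keras TensorBoard; the agent/env objects, step count and interval are placeholders, and this repo's exact import paths are not shown in the hunk):

from keras.callbacks import TensorBoard
from rl.callbacks import ModelIntervalCheckpoint

callbacks = [
    ModelIntervalCheckpoint(
        './checkpoints/checkpoint_weights.h5f', interval=10000),
    TensorBoard(log_dir='./logs'),
]
# agent.fit(env, nb_steps=100000, callbacks=callbacks, nb_max_episode_steps=400)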
......@@ -291,11 +301,15 @@ class DQNLearner(LearnerBase):
Returns:
KerasRL DQN object
"""
agent = DQNAgentOverOptions(model=model, low_level_policies=self.low_level_policies,
nb_actions=self.nb_actions, memory=memory,
nb_steps_warmup=self.nb_steps_warmup, target_model_update=self.target_model_update,
policy=policy, enable_dueling_network=True)
agent = DQNAgentOverOptions(
model=model,
low_level_policies=self.low_level_policies,
nb_actions=self.nb_actions,
memory=memory,
nb_steps_warmup=self.nb_steps_warmup,
target_model_update=self.target_model_update,
policy=policy,
enable_dueling_network=True)
agent.compile(Adam(lr=self.lr), metrics=['mae'])
......@@ -312,7 +326,11 @@ class DQNLearner(LearnerBase):
callbacks = []
if model_checkpoints:
callbacks += [ModelIntervalCheckpoint('./checkpoints/checkpoint_weights.h5f', interval=checkpoint_interval)]
callbacks += [
ModelIntervalCheckpoint(
'./checkpoints/checkpoint_weights.h5f',
interval=checkpoint_interval)
]
if tensorboard:
callbacks += [TensorBoard(log_dir='./logs')]
......@@ -333,7 +351,7 @@ class DQNLearner(LearnerBase):
nb_episodes=5,
visualize=True,
nb_max_episode_steps=400,
success_reward_threshold = 100):
success_reward_threshold=100):
print("Testing for {} episodes".format(nb_episodes))
success_count = 0
......@@ -359,13 +377,14 @@ class DQNLearner(LearnerBase):
env.reset()
if episode_reward >= success_reward_threshold:
success_count += 1
print("Episode {}: steps:{}, reward:{}".format(n+1, step, episode_reward))
print("Episode {}: steps:{}, reward:{}".format(
n + 1, step, episode_reward))
print ("\nPolicy succeeded {} times!".format(success_count))
print ("Failures due to:")
print (termination_reason_counter)
print("\nPolicy succeeded {} times!".format(success_count))
print("Failures due to:")
print(termination_reason_counter)
return [success_count,termination_reason_counter]
return [success_count, termination_reason_counter]
def load_model(self, file_name="test_weights.h5f"):
self.agent_model.load_weights(file_name)
......@@ -377,31 +396,43 @@ class DQNLearner(LearnerBase):
return self.agent_model.get_modified_q_values(observation)[action]
def get_q_value_using_option_alias(self, observation, option_alias):
action_num = self.agent_model.low_level_policy_aliases.index(option_alias)
action_num = self.agent_model.low_level_policy_aliases.index(
option_alias)
return self.agent_model.get_modified_q_values(observation)[action_num]
def get_softq_value_using_option_alias(self, observation, option_alias):
action_num = self.agent_model.low_level_policy_aliases.index(option_alias)
action_num = self.agent_model.low_level_policy_aliases.index(
option_alias)
q_values = self.agent_model.get_modified_q_values(observation)
max_q_value = np.abs(np.max(q_values))
q_values = [np.exp(q_value/max_q_value) for q_value in q_values]
relevant = q_values[action_num]/np.sum(q_values)
q_values = [np.exp(q_value / max_q_value) for q_value in q_values]
relevant = q_values[action_num] / np.sum(q_values)
return relevant
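get_softq_value_using_option_alias above rescales every Q-value by the largest absolute Q, exponentiates, and returns the chosen option's share of the total, i.e. a softmax-style probability. A standalone version of that computation (numpy only; the example Q-values are made up):

import numpy as np

def soft_q_probability(q_values, action_num):
    """Softmax-like weight of one option's Q-value, scaled by the max |Q|."""
    q_values = np.asarray(q_values, dtype=float)
    max_q_value = np.abs(np.max(q_values))
    exp_q = np.exp(q_values / max_q_value)
    return exp_q[action_num] / np.sum(exp_q)

print(soft_q_probability([10.0, 2.0, -5.0], 0))  # the largest Q gets the largest share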
class DQNAgentOverOptions(DQNAgent):
def __init__(self, model, low_level_policies, policy=None, test_policy=None, enable_double_dqn=True, enable_dueling_network=False,
dueling_type='avg', *args, **kwargs):
super(DQNAgentOverOptions, self).__init__(model, policy, test_policy, enable_double_dqn, enable_dueling_network,
dueling_type, *args, **kwargs)
class DQNAgentOverOptions(DQNAgent):
def __init__(self,
model,
low_level_policies,
policy=None,
test_policy=None,
enable_double_dqn=True,
enable_dueling_network=False,
dueling_type='avg',
*args,
**kwargs):
super(DQNAgentOverOptions, self).__init__(
model, policy, test_policy, enable_double_dqn,
enable_dueling_network, dueling_type, *args, **kwargs)
self.low_level_policies = low_level_policies
if low_level_policies is not None:
self.low_level_policy_aliases = list(self.low_level_policies.keys())
self.low_level_policy_aliases = list(
self.low_level_policies.keys())
def __get_invalid_node_indices(self):
"""Returns a list of option indices that are invalid according to initiation conditions.
"""
"""Returns a list of option indices that are invalid according to
initiation conditions."""
invalid_node_indices = list()
for index, option_alias in enumerate(self.low_level_policy_aliases):
self.low_level_policies[option_alias].reset_maneuver()
......@@ -435,5 +466,3 @@ class DQNAgentOverOptions(DQNAgent):
q_values[node_index] = -np.inf
return q_values
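The -np.inf masking above guarantees that options whose initiation condition currently fails can never win a greedy argmax over Q-values. A toy illustration of the effect (numpy only; the numbers and indices are made up):

import numpy as np

q_values = np.array([1.2, 3.4, 0.7, 2.9])
invalid_node_indices = [1, 3]  # e.g. options whose initiation condition failed

q_values[invalid_node_indices] = -np.inf
best_option = int(np.argmax(q_values))  # 0: best among the remaining valid options
print(q_values, best_option)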
from.policy_base import PolicyBase
from .policy_base import PolicyBase
import numpy as np
......
from .controller_base import ControllerBase
class ManualPolicy(ControllerBase):
"""Manual policy execution using nodes and edges."""
def __init__(self, env, low_level_policies, transition_adj, start_node_alias):
def __init__(self, env, low_level_policies, transition_adj,
start_node_alias):
"""Constructor for manual policy execution.
Args:
......@@ -13,12 +15,13 @@ class ManualPolicy(ControllerBase):
start_node: starting node
"""
super(ManualPolicy, self).__init__(env, low_level_policies, start_node_alias)
super(ManualPolicy, self).__init__(env, low_level_policies,
start_node_alias)
self.adj = transition_adj
def _transition(self):
"""Check if the current node's termination condition is met and if
it is possible to transition to another node, i.e. its initiation
"""Check if the current node's termination condition is met and if it
is possible to transition to another node, i.e. its initiation
condition is met. This is an internal function.
Returns the new node if a transition can happen, None otherwise
......
......@@ -2,15 +2,15 @@ from .controller_base import ControllerBase
import numpy as np
import pickle
class MCTSLearner(ControllerBase):
"""Monte Carlo Tree Search implementation using the UCB1 and
progressive widening approach as explained in Paxton et al (2017).
"""
"""Monte Carlo Tree Search implementation using the UCB1 and progressive
widening approach as explained in Paxton et al (2017)."""
_ucb_vals = set()
def __init__(self, env, low_level_policies,
start_node_alias, max_depth=10):
def __init__(self, env, low_level_policies, start_node_alias,
max_depth=10):
"""Constructor for MCTSLearner.
Args:
......@@ -22,10 +22,12 @@ class MCTSLearner(ControllerBase):
max_depth: max depth of the MCTS tree; default 10 levels
"""
super(MCTSLearner, self).__init__(env, low_level_policies, start_node_alias)
super(MCTSLearner, self).__init__(env, low_level_policies,
start_node_alias)
self.controller_args_defaults = {
"predictor": None #P(s, o) learner class; forward pass should return the entire value from state s and option o
"predictor":
None #P(s, o) learner class; forward pass should return the entire value from state s and option o
}
self.max_depth = max_depth
#: store current node alias
......@@ -51,8 +53,14 @@ class MCTSLearner(ControllerBase):
self.adj[root_node_num] = set() # no children
def save_model(self, file_name="mcts.pickle"):
to_backup = {'N': self.N, 'M': self.M, 'TR': self.TR, 'nodes': self.nodes,
'adj': self.adj, 'new_node_num': self.new_node_num}
to_backup = {
'N': self.N,
'M': self.M,
'TR': self.TR,
'nodes': self.nodes,
'adj': self.adj,
'new_node_num': self.new_node_num
}
with open(file_name, 'wb') as handle:
pickle.dump(to_backup, handle, protocol=pickle.HIGHEST_PROTOCOL)
......@@ -67,8 +75,8 @@ class MCTSLearner(ControllerBase):
self.new_node_num = to_restore['new_node_num']
def _create_node(self, low_level_policy):
"""Create the node associated with curr_node_num, using the
given low level policy.
"""Create the node associated with curr_node_num, using the given low
level policy.
Args:
low_level_policy: the option's alias
......@@ -83,11 +91,10 @@ class MCTSLearner(ControllerBase):
return created_node_num, {"policy": low_level_policy}
def _to_discrete(self, observation):
"""Converts observation to a discrete observation tuple. Also
append (a) whether we are following a vehicle, and (b) whether
there is a vehicle in the opposite lane in the approximately
the same x position. These values will be useful for Follow
and ChangeLane maneuvers.
"""Converts observation to a discrete observation tuple. Also append
(a) whether we are following a vehicle, and (b) whether there is a
vehicle in the opposite lane in the approximately the same x position.
These values will be useful for Follow and ChangeLane maneuvers.
Args:
observation: observation tuple from the environment
......@@ -96,9 +103,9 @@ class MCTSLearner(ControllerBase):
"""
dis_observation = ''
for item in observation[12:20]:
if type(item)==bool:
if type(item) == bool:
dis_observation += '1' if item is True else '0'
if type(item)==int and item in [0, 1]:
if type(item) == int and item in [0, 1]:
dis_observation += str(item)
env = self.current_node.env
......@@ -128,9 +135,9 @@ class MCTSLearner(ControllerBase):
def _get_visitation_count(self, observation, option=None):
"""Finds the visitation count of the discrete form of the observation.
If the discrete observation is not found, it is inserted into self.N
with value 0. Automatically converts the observation into discrete form.
If option is not None, then this uses self.M instead of self.N
If the discrete observation is not found, it is inserted into self.N with
value 0. Automatically converts the observation into discrete form. If
option is not None, then this uses self.M instead of self.N.
Args:
observation: observation tuple from the environment
......@@ -163,12 +170,14 @@ class MCTSLearner(ControllerBase):
dis_observation = self._to_discrete(observation)
if (dis_observation, option) not in self.TR:
self.TR[(dis_observation, option)] = 0
return self.TR[(dis_observation, option)] / (1+self._get_visitation_count(observation, option))
return self.TR[(dis_observation, option)] / (
1 + self._get_visitation_count(observation, option))
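_get_q_star above estimates an option's value as the accumulated return TR(s, o) divided by 1 plus the visitation count, a running average where the +1 guards against division by zero for unvisited pairs. A tiny standalone sketch (the dictionaries, keys, and numbers are placeholders):

# Q*(s, o) = TR[(s, o)] / (1 + M[(s, o)])
TR = {('0110', 'follow'): 42.0}  # accumulated return per (discrete obs, option)
M = {('0110', 'follow'): 6}      # visitation count per (discrete obs, option)

key = ('0110', 'follow')
q_star = TR.get(key, 0.0) / (1 + M.get(key, 0))
print(q_star)  # 6.0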
def _get_possible_options(self):
"""Returns a set of options that can be taken from the current node.
Goes through the adjacency set of the current node and finds which next nodes'
initiation conditions are met.
Goes through the adjacency set of the current node and finds which next
nodes' initiation conditions are met.
"""
all_options = set(self.low_level_policies.keys())
......@@ -201,9 +210,9 @@ class MCTSLearner(ControllerBase):
return visited_aliases
def _ucb_adjusted_q(self, observation, C=1):
"""Computes Q_star(observation, option_i) plus the UCB term, which
is C*[predictor(observation, option_i)]/[1+N(observation, option_i)],
for all option_i in the adjacency set of the current node.
"""Computes Q_star(observation, option_i) plus the UCB term, which is
C*[predictor(observation, option_i)]/[1+N(observation, option_i)], for
all option_i in the adjacency set of the current node.
Args:
observation: observation tuple from the environment
......@@ -218,7 +227,8 @@ class MCTSLearner(ControllerBase):
next_option_nums = self.adj[self.curr_node_num]
for next_option_num in next_option_nums:
next_option = self.nodes[next_option_num]["policy"]
Q1[next_option] = (self._get_q_star(observation, next_option)+200)/400
Q1[next_option] = (
self._get_q_star(observation, next_option) + 200) / 400
Q[(dis_observation, next_option)] = \
Q1[next_option]
Q2[next_option] = C * \
......@@ -243,11 +253,11 @@ class MCTSLearner(ControllerBase):
relevant_rewards = [value for key, value in self.TR.items() \
if key[0] == dis_observation]
sum_rewards = np.sum(relevant_rewards)
return sum_rewards / (1+self._get_visitation_count(observation))
return sum_rewards / (1 + self._get_visitation_count(observation))
def _select(self, observation, depth=0, visualize=False):
"""MCTS selection function. For representation, we only use
the discrete part of the observation.
"""MCTS selection function. For representation, we only use the
discrete part of the observation.
Args:
observation: observation tuple from the environment
......@@ -266,7 +276,8 @@ class MCTSLearner(ControllerBase):
if is_terminal or max_depth_reached:
# print('MCTS went %d nodes deep' % depth)
return self._value(observation), 0 # TODO: replace with final goal reward
return self._value(
observation), 0 # TODO: replace with final goal reward
Ns = self._get_visitation_count(observation)
Nchildren = len(self.adj[self.curr_node_num])
......@@ -289,22 +300,23 @@ class MCTSLearner(ControllerBase):
# Find o_star and do a transition, i.e. update curr_node
# Simulate / lookup; first change next
next_observation, episode_R, o_star = self.do_transition(observation,
visualize=visualize)
next_observation, episode_R, o_star = self.do_transition(
observation, visualize=visualize)
# Recursively select next node
remaining_v, all_ep_R = self._select(next_observation, depth+1, visualize=visualize)
remaining_v, all_ep_R = self._select(
next_observation, depth + 1, visualize=visualize)
# Update values
self.N[dis_observation] += 1
self.M[(dis_observation, o_star)] += 1
self.TR[(dis_observation, o_star)] += (episode_R + remaining_v)
return self._value(observation), all_ep_R+episode_R
return self._value(observation), all_ep_R + episode_R
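The selection and backup above follow the usual MCTS pattern: score each child by a normalized value estimate plus an exploration bonus C * predictor(s, o) / (1 + N(s, o)), descend, then on the way back increment N and M and add the observed return into TR. A self-contained sketch of just the scoring rule (the prior values and visit counts are made up; the [-200, 200] reward range is inferred from the (q + 200) / 400 normalization above):

def ucb_adjusted_q(q_star, prior, n_visits, C=1.0):
    """Normalized value estimate in [0, 1] plus a prior-weighted exploration bonus."""
    q1 = (q_star + 200.0) / 400.0      # assumes rewards roughly in [-200, 200]
    q2 = C * prior / (1.0 + n_visits)  # bonus decays as the option is visited
    return q1 + q2

# An unvisited option with a strong prior outranks a well-visited option
# with a somewhat higher value estimate.
print(ucb_adjusted_q(q_star=50.0, prior=0.9, n_visits=0))   # 1.525
print(ucb_adjusted_q(q_star=80.0, prior=0.1, n_visits=20))  # ~0.705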
def traverse(self, observation, visualize=False):
"""Do a complete traversal from root to leaf. Assumes the
environment is reset and we are at the root node.
"""Do a complete traversal from root to leaf. Assumes the environment
is reset and we are at the root node.
Args:
observation: observation from the environment
......@@ -316,8 +328,8 @@ class MCTSLearner(ControllerBase):
return self._select(observation, visualize=visualize)
def do_transition(self, observation, visualize=False):
"""Do a transition using UCB metric, with the latest observation
from the episodic step.
"""Do a transition using UCB metric, with the latest observation from
the episodic step.
Args:
observation: final observation from episodic step
......
......@@ -3,8 +3,9 @@ from .mcts_learner import MCTSLearner
import tqdm
import numpy as np
class OnlineMCTSController(ControllerBase):
"""Online MCTS"""
"""Online MCTS."""
def __init__(self, env, low_level_policies, start_node_alias):
"""Constructor for manual policy execution.
......@@ -13,7 +14,8 @@ class OnlineMCTSController(ControllerBase):
env: env instance
low_level_policies: low level policies dictionary
"""
super(OnlineMCTSController, self).__init__(env, low_level_policies, start_node_alias)
super(OnlineMCTSController, self).__init__(env, low_level_policies,
start_node_alias)
self.curr_node_alias = start_node_alias
self.controller_args_defaults = {
"predictor": None,
......@@ -48,12 +50,14 @@ class OnlineMCTSController(ControllerBase):
env_before_mcts = orig_env.copy()
self.change_low_level_references(env_before_mcts)
print('Current Node: %s' % self.curr_node_alias)
mcts = MCTSLearner(self.env, self.low_level_policies, self.curr_node_alias)
mcts = MCTSLearner(self.env, self.low_level_policies,
self.curr_node_alias)
mcts.max_depth = self.max_depth
mcts.set_controller_args(predictor = self.predictor)
mcts.set_controller_args(predictor=self.predictor)
# Do nb_traversals number of traversals, reset env to this point every time
# print('Doing MCTS with params: max_depth = %d, nb_traversals = %d' % (self.max_depth, self.nb_traversals))
for num_epoch in range(self.nb_traversals): # tqdm.tqdm(range(self.nb_traversals)):
for num_epoch in range(
self.nb_traversals): # tqdm.tqdm(range(self.nb_traversals)):
mcts.curr_node_num = 0
env_begin_epoch = env_before_mcts.copy()
self.change_low_level_references(env_begin_epoch)
......@@ -63,6 +67,7 @@ class OnlineMCTSController(ControllerBase):
# Find the nodes from the root node
mcts.curr_node_num = 0
print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
node_after_transition = mcts.get_best_node(self.env.get_features_tuple(), use_ucb=False)
node_after_transition = mcts.get_best_node(
self.env.get_features_tuple(), use_ucb=False)
print('MCTS suggested next option: %s' % node_after_transition)
self.set_current_node(node_after_transition)
class PolicyBase:
"""Abstract policy base from which every policy backend is defined
and inherited."""
\ No newline at end of file
"""Abstract policy base from which every policy backend is defined and
inherited."""
......@@ -11,7 +11,8 @@ class RLController(ControllerBase):
env: env instance
low_level_policies: low level policies dictionary
"""
super(RLController, self).__init__(env, low_level_policies, start_node_alias)
super(RLController, self).__init__(env, low_level_policies,
start_node_alias)
self.low_level_policy_aliases = list(self.low_level_policies.keys())
self.trained_policy = None
self.node_terminal_state_reached = False
......@@ -32,6 +33,8 @@ class RLController(ControllerBase):
if self.trained_policy is None:
raise Exception(self.__class__.__name__ + \
"trained_policy is not set. Use set_trained_policy().")
node_index_after_transition = self.trained_policy(self.env.get_features_tuple())
self.set_current_node(self.low_level_policy_aliases[node_index_after_transition])
node_index_after_transition = self.trained_policy(
self.env.get_features_tuple())
self.set_current_node(
self.low_level_policy_aliases[node_index_after_transition])
self.node_terminal_state_reached = False
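can_transition above delegates to self.trained_policy, which is expected to be a callable mapping the env's feature tuple to an index into low_level_policy_aliases. A toy stand-in showing the expected interface (the alias names and the constant policy are placeholders, not this repo's learned policy):

low_level_policy_aliases = ['keeplane', 'changelane', 'follow']  # placeholder aliases

def trained_policy(features):
    """Toy stand-in: always pick option 0; a real policy would be learned."""
    return 0

features = ()  # placeholder for env.get_features_tuple()
node_alias = low_level_policy_aliases[trained_policy(features)]
print(node_alias)  # 'keeplane'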
......@@ -3,22 +3,22 @@ from model_checker import Parser
class GymCompliantEnvBase:
def step(self, action):
""" Gym compliant step function which
will be implemented in the subclass.
"""
raise NotImplemented(self.__class__.__name__ + "step is not implemented.")
"""Gym compliant step function which will be implemented in the
subclass."""
raise NotImplemented(self.__class__.__name__ +
"step is not implemented.")
def reset(self):
""" Gym compliant reset function which
will be implemented in the subclass.
"""
raise NotImplemented(self.__class__.__name__ + "reset is not implemented.")
"""Gym compliant reset function which will be implemented in the
subclass."""
raise NotImplemented(self.__class__.__name__ +
"reset is not implemented.")
def render(self):
""" Gym compliant step function which
will be implemented in the subclass.
"""
raise NotImplemented(self.__class__.__name__ + "render is not implemented.")
"""Gym compliant step function which will be implemented in the
subclass."""
raise NotImplemented(self.__class__.__name__ +
"render is not implemented.")
class EpisodicEnvBase(GymCompliantEnvBase):
......@@ -46,6 +46,7 @@ class EpisodicEnvBase(GymCompliantEnvBase):
def _init_LTL_preconditions(self):
"""Initialize the LTL preconditions (self._LTL_preconditions)..
in the subclass.
"""
return
......@@ -69,21 +70,23 @@ class EpisodicEnvBase(GymCompliantEnvBase):
self._r_terminal = None
def _terminal_reward_superposition(self, r_obs):
"""Calculate the next value when observing "obs," from the prior
using min, max, or summation depending on the superposition_type.
"""
"""Calculate the next value when observing "obs," from the prior using
min, max, or summation depending on the superposition_type."""
if r_obs is None:
return
if self.terminal_reward_type == 'min':
self._r_terminal = r_obs if self._r_terminal is None else min(self._r_terminal, r_obs)
self._r_terminal = r_obs if self._r_terminal is None else min(
self._r_terminal, r_obs)
elif self.terminal_reward_type == 'max':
self._r_terminal = r_obs if self._r_terminal is None else max(self._r_terminal, r_obs)
self._r_terminal = r_obs if self._r_terminal is None else max(
self._r_terminal, r_obs)
elif self.terminal_reward_type == 'sum':
self._r_terminal = r_obs if self._r_terminal is None else self._r_terminal + r_obs
else:
raise AssertionError("The terminal_reward_type has to be 'min', 'max', or 'sum'")
raise AssertionError(
"The terminal_reward_type has to be 'min', 'max', or 'sum'")
def step(self, u):
# the penalty is a negative reward.
......@@ -118,7 +121,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
if LTL_precondition.enabled:
LTL_precondition.check_incremental(self.__mc_AP)
if LTL_precondition.result == Parser.FALSE:
self._terminal_reward_superposition(EpisodicEnvBase._reward(LTL_precondition.penalty))
self._terminal_reward_superposition(
EpisodicEnvBase._reward(LTL_precondition.penalty))
violate = True
info['ltl_violation'] = LTL_precondition.str
# print("\nViolation of \"" + LTL_precondition.str + "\"")
......@@ -126,7 +130,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
return violate, info
def current_model_checking_result(self):
"""Returns whether or not any of the conditions is currently violated."""
"""Returns whether or not any of the conditions is currently
violated."""
for LTL_precondition in self._LTL_preconditions:
if LTL_precondition.result == Parser.FALSE:
......@@ -148,9 +153,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
@property
def termination_condition(self):
"""In the subclass, specify the condition for termination of the episode
(or the maneuver).
"""
"""In the subclass, specify the condition for termination of the
episode (or the maneuver)."""
if self._terminate_in_goal and self.goal_achieved:
return True
......@@ -177,7 +181,6 @@ class EpisodicEnvBase(GymCompliantEnvBase):
return True
return False
# TODO: replace these confusing reward and penalty methods, or avoid using both reward and penalty in the property naming.
@staticmethod
def _reward(penalty):
......@@ -186,4 +189,3 @@ class EpisodicEnvBase(GymCompliantEnvBase):
@staticmethod
def _penalty(reward):
return None if reward is None else -reward
class RoadEnv:
"""The generic road env
"""The generic road env.
TODO: Implement this generic road env for plugging-in other road envs.
TODO: roadEnv also having a step() function can cause a problem.
......
......@@ -69,7 +69,8 @@ ego_feature_dict = dict()
for key in con_ego_feature_dict.keys():
ego_feature_dict[key] = con_ego_feature_dict[key]
for key in dis_ego_feature_dict.keys():
ego_feature_dict[key] = dis_ego_feature_dict[key] + len(con_ego_feature_dict)
ego_feature_dict[