Commit f2171d2c authored by Aravind Balakrishnan

Merge branch 'formatting' into 'master'

Formatting

See merge request !3
parents e1fdb162 7abc600f
*.pyc
*.py~
__pycache__
@@ -35,7 +35,9 @@ class PPO2Agent(LearnerBase):
        self.log_path = log_path
        self.env = DummyVecEnv([
            lambda: env
        ])  # PPO2 requires a vectorized environment for parallel training
        self.agent_model = self.create_agent(policy, tensorboard)

    def get_default_policy(self):
@@ -46,13 +48,13 @@ class PPO2Agent(LearnerBase):
        return MlpPolicy

    def create_agent(self, policy, tensorboard):
        """Creates a PPO agent.

        Returns: stable_baselines PPO2 object
        """
        if tensorboard:
            return PPO2(
                policy, self.env, verbose=1, tensorboard_log=self.log_path)
        else:
            return PPO2(policy, self.env, verbose=1)
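As the comment above notes, PPO2 works on vectorized environments, which is why the constructor wraps the single env in a one-element DummyVecEnv before building the agent. A minimal standalone sketch of the same pattern, assuming stable-baselines 2.x; the CartPole environment is only a placeholder, not the project's env:

import gym
from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv

# PPO2 expects a VecEnv; wrap a single Gym env in a one-element DummyVecEnv.
env = DummyVecEnv([lambda: gym.make('CartPole-v1')])

# Roughly what create_agent(MlpPolicy, tensorboard=False) does above.
model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=10000)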
@@ -100,7 +102,8 @@ class PPO2Agent(LearnerBase):
            episode_rewards[-1] += rewards[0]
            if dones[0] or current_step > nb_max_episode_steps:
                obs = self.env.reset()
                print("Episode ", current_episode, "reward: ",
                      episode_rewards[-1])
                episode_rewards.append(0.0)
                current_episode += 1
                current_step = 0
@@ -18,16 +18,17 @@ class ControllerBase(PolicyBase):
            setattr(self, prop, kwargs.get(prop, default))

    def can_transition(self):
        """Returns boolean signifying whether we can transition.

        To be implemented in subclass.
        """
        raise NotImplementedError(self.__class__.__name__ + \
            "can_transition is not implemented.")

    def do_transition(self, observation):
        """Do a transition, assuming we can transition. To be implemented in
        subclass.

        Args:
            observation: final observation from episodic step
@@ -37,7 +38,7 @@ class ControllerBase(PolicyBase):
            "do_transition is not implemented.")

    def set_current_node(self, node_alias):
        """Sets the current node which is being executed.

        Args:
            node: node alias of the node to be set
@@ -50,11 +51,13 @@ class ControllerBase(PolicyBase):
    Returns state at end of node execution, total reward, episode_termination_flag, info
    '''

    def step_current_node(self, visualize_low_level_steps=False):
        total_reward = 0
        self.node_terminal_state_reached = False
        while not self.node_terminal_state_reached:
            observation, reward, terminal, info = self.low_level_step_current_node(
            )
            if visualize_low_level_steps:
                self.env.render()
            total_reward += reward
@@ -70,9 +73,12 @@ class ControllerBase(PolicyBase):
    Returns state after one step, step reward, episode_termination_flag, info
    '''

    def low_level_step_current_node(self):
        u_ego = self.current_node.low_level_policy(
            self.current_node.get_reduced_features_tuple())
        feature, R, terminal, info = self.current_node.step(u_ego)
        self.node_terminal_state_reached = terminal
        return self.env.get_features_tuple(
        ), R, self.env.termination_condition, info
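The two methods above are the core of option execution: step_current_node runs the currently selected low-level policy until its termination condition is reached, while low_level_step_current_node advances the environment by one low-level action. A self-contained sketch of the same control flow with a toy option in place of the project's maneuver policies (ToyOption, the +1 dynamics, and the step cost are invented for illustration):

class ToyOption:
    """Stands in for a low-level maneuver policy with a termination test."""

    def __init__(self, target):
        self.target = target

    def policy(self, state):
        return 1 if state < self.target else 0  # crude low-level action

    def terminal(self, state):
        return state >= self.target


def run_option_to_termination(state, option, visualize=False):
    """Mimics step_current_node: loop until the option's termination test fires."""
    total_reward = 0.0
    while not option.terminal(state):
        action = option.policy(state)  # low-level action from the option
        state = state + action         # toy environment transition
        total_reward += -1.0           # constant per-step cost
        if visualize:
            print("state:", state)
    return state, total_reward


final_state, reward = run_option_to_termination(0, ToyOption(target=5))
print(final_state, reward)  # 5, -5.0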
@@ -47,7 +47,8 @@ class DDPGLearner(LearnerBase):
            "oup_mu": 0,  # OrnsteinUhlenbeckProcess mu
            "oup_sigma": 1,  # OrnsteinUhlenbeckProcess sigma
            "oup_sigma_min": 0.5,  # OrnsteinUhlenbeckProcess sigma min
            "oup_annealing_steps":
            500000,  # OrnsteinUhlenbeckProcess n-step annealing
            "nb_steps_warmup_critic": 100,  # steps for critic to warmup
            "nb_steps_warmup_actor": 100,  # steps for actor to warmup
            "target_model_update": 1e-3  # target model update frequency
@@ -160,24 +161,33 @@ class DDPGLearner(LearnerBase):
            target_model_update=1e-3)

        # TODO: give params like lr_actor and lr_critic to set different lr of Actor and Critic.
        agent.compile(
            [
                Adam(lr=self.lr * 1e-2, clipnorm=1.),
                Adam(lr=self.lr, clipnorm=1.)
            ],
            metrics=['mae'])

        return agent

    def train(self,
              env,
              nb_steps=1000000,
              visualize=False,
              verbose=1,
              log_interval=10000,
              nb_max_episode_steps=200,
              model_checkpoints=False,
              checkpoint_interval=100000,
              tensorboard=False):

        callbacks = []
        if model_checkpoints:
            callbacks += [
                ModelIntervalCheckpoint(
                    './checkpoints/checkpoint_weights.h5f',
                    interval=checkpoint_interval)
            ]
        if tensorboard:
            callbacks += [TensorBoard(log_dir='./logs')]
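The checkpoint and TensorBoard callbacks assembled above are passed straight to keras-rl's fit(). A condensed sketch of that wiring, assuming keras-rl's ModelIntervalCheckpoint and the Keras TensorBoard callback, with an already-constructed agent and env; the file paths are only examples:

from rl.callbacks import ModelIntervalCheckpoint
from keras.callbacks import TensorBoard


def build_callbacks(model_checkpoints=False, checkpoint_interval=100000,
                    tensorboard=False):
    """Mirrors the callback assembly used in train() above."""
    callbacks = []
    if model_checkpoints:
        # Saves agent weights every `checkpoint_interval` training steps.
        callbacks.append(
            ModelIntervalCheckpoint(
                './checkpoints/checkpoint_weights.h5f',
                interval=checkpoint_interval))
    if tensorboard:
        callbacks.append(TensorBoard(log_dir='./logs'))
    return callbacks


# agent.fit(env, nb_steps=1000000, visualize=False, verbose=1,
#           log_interval=10000, nb_max_episode_steps=200,
#           callbacks=build_callbacks(model_checkpoints=True))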
@@ -291,28 +301,36 @@ class DQNLearner(LearnerBase):
        Returns:
            KerasRL DQN object
        """
        agent = DQNAgentOverOptions(
            model=model,
            low_level_policies=self.low_level_policies,
            nb_actions=self.nb_actions,
            memory=memory,
            nb_steps_warmup=self.nb_steps_warmup,
            target_model_update=self.target_model_update,
            policy=policy,
            enable_dueling_network=True)

        agent.compile(Adam(lr=self.lr), metrics=['mae'])
        return agent

    def train(self,
              env,
              nb_steps=1000000,
              visualize=False,
              nb_max_episode_steps=200,
              tensorboard=False,
              model_checkpoints=False,
              checkpoint_interval=10000):

        callbacks = []
        if model_checkpoints:
            callbacks += [
                ModelIntervalCheckpoint(
                    './checkpoints/checkpoint_weights.h5f',
                    interval=checkpoint_interval)
            ]
        if tensorboard:
            callbacks += [TensorBoard(log_dir='./logs')]
@@ -333,7 +351,7 @@ class DQNLearner(LearnerBase):
             nb_episodes=5,
             visualize=True,
             nb_max_episode_steps=400,
             success_reward_threshold=100):

        print("Testing for {} episodes".format(nb_episodes))
        success_count = 0
@@ -359,13 +377,14 @@ class DQNLearner(LearnerBase):
            env.reset()
            if episode_reward >= success_reward_threshold:
                success_count += 1
            print("Episode {}: steps:{}, reward:{}".format(
                n + 1, step, episode_reward))

        print("\nPolicy succeeded {} times!".format(success_count))
        print("Failures due to:")
        print(termination_reason_counter)

        return [success_count, termination_reason_counter]

    def load_model(self, file_name="test_weights.h5f"):
        self.agent_model.load_weights(file_name)
@@ -377,31 +396,43 @@ class DQNLearner(LearnerBase):
        return self.agent_model.get_modified_q_values(observation)[action]

    def get_q_value_using_option_alias(self, observation, option_alias):
        action_num = self.agent_model.low_level_policy_aliases.index(
            option_alias)
        return self.agent_model.get_modified_q_values(observation)[action_num]

    def get_softq_value_using_option_alias(self, observation, option_alias):
        action_num = self.agent_model.low_level_policy_aliases.index(
            option_alias)
        q_values = self.agent_model.get_modified_q_values(observation)
        max_q_value = np.abs(np.max(q_values))
        q_values = [np.exp(q_value / max_q_value) for q_value in q_values]
        relevant = q_values[action_num] / np.sum(q_values)
        return relevant
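get_softq_value_using_option_alias turns the raw option Q-values into a softmax-like score: each value is exponentiated after scaling by the magnitude of the largest Q-value, then normalized so the scores sum to one. A standalone numpy sketch of that computation; the Q-values below are made-up numbers:

import numpy as np


def soft_q_scores(q_values):
    """Scale by |max Q|, exponentiate, and normalize to a distribution."""
    q_values = np.asarray(q_values, dtype=float)
    max_q_value = np.abs(np.max(q_values))
    exp_q = np.exp(q_values / max_q_value)
    return exp_q / np.sum(exp_q)


scores = soft_q_scores([2.0, 1.0, -1.0])
print(scores)        # the highest Q-value gets the largest share
print(scores.sum())  # 1.0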

class DQNAgentOverOptions(DQNAgent):
    def __init__(self,
                 model,
                 low_level_policies,
                 policy=None,
                 test_policy=None,
                 enable_double_dqn=True,
                 enable_dueling_network=False,
                 dueling_type='avg',
                 *args,
                 **kwargs):
        super(DQNAgentOverOptions, self).__init__(
            model, policy, test_policy, enable_double_dqn,
            enable_dueling_network, dueling_type, *args, **kwargs)

        self.low_level_policies = low_level_policies
        if low_level_policies is not None:
            self.low_level_policy_aliases = list(
                self.low_level_policies.keys())

    def __get_invalid_node_indices(self):
        """Returns a list of option indices that are invalid according to
        initiation conditions."""
        invalid_node_indices = list()
        for index, option_alias in enumerate(self.low_level_policy_aliases):
            self.low_level_policies[option_alias].reset_maneuver()
@@ -435,5 +466,3 @@ class DQNAgentOverOptions(DQNAgent):
            q_values[node_index] = -np.inf

        return q_values
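DQNAgentOverOptions restricts the option set by overwriting the Q-values of options whose initiation conditions fail with -inf, so a greedy argmax can never select them. A small self-contained sketch of that masking step; the example Q-values and the invalid index are invented:

import numpy as np


def mask_invalid_options(q_values, invalid_indices):
    """Set invalid options to -inf so they can never win an argmax."""
    q_values = np.array(q_values, dtype=float)
    for node_index in invalid_indices:
        q_values[node_index] = -np.inf
    return q_values


q = mask_invalid_options([0.3, 1.2, 0.7], invalid_indices=[1])
print(q)             # [ 0.3 -inf  0.7]
print(np.argmax(q))  # 2 -- the masked option is skipped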
from .policy_base import PolicyBase
import numpy as np
@@ -23,10 +23,10 @@ class LearnerBase(PolicyBase):
            setattr(self, prop, kwargs.get(prop, default))

    def train(self,
              env,
              nb_steps=50000,
              visualize=False,
              nb_max_episode_steps=200):
        """Train the learning agent on the environment.

        Args:
from .controller_base import ControllerBase


class ManualPolicy(ControllerBase):
    """Manual policy execution using nodes and edges."""

    def __init__(self, env, low_level_policies, transition_adj,
                 start_node_alias):
        """Constructor for manual policy execution.

        Args:
@@ -13,12 +15,13 @@ class ManualPolicy(ControllerBase):
            start_node: starting node
        """
        super(ManualPolicy, self).__init__(env, low_level_policies,
                                           start_node_alias)
        self.adj = transition_adj

    def _transition(self):
        """Check if the current node's termination condition is met and if it
        is possible to transition to another node, i.e. its initiation
        condition is met. This is an internal function.

        Returns the new node if a transition can happen, None otherwise
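As the docstring describes, _transition checks the current node's termination condition and then scans the transition adjacency for a neighbour whose initiation condition holds, returning that node or None. A toy sketch of that rule with hand-written condition functions; the node names and conditions below are invented, not the project's maneuvers:

def find_transition(current, adjacency, terminated, can_initiate):
    """Return the next node alias if a transition is possible, else None."""
    if not terminated(current):
        return None
    for neighbour in adjacency.get(current, []):
        if can_initiate(neighbour):
            return neighbour
    return None


adjacency = {'keep_lane': ['change_lane'], 'change_lane': ['keep_lane']}
next_node = find_transition(
    'keep_lane', adjacency,
    terminated=lambda node: True,                  # current maneuver finished
    can_initiate=lambda node: node == 'change_lane')
print(next_node)  # change_lane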
@@ -50,4 +53,4 @@ class ManualPolicy(ControllerBase):
        new_node = self._transition()
        if new_node is not None:
            self.current_node = new_node
@@ -3,8 +3,9 @@ from .mcts_learner import MCTSLearner
import tqdm
import numpy as np


class OnlineMCTSController(ControllerBase):
    """Online MCTS."""

    def __init__(self, env, low_level_policies, start_node_alias):
        """Constructor for manual policy execution.
@@ -13,12 +14,13 @@ class OnlineMCTSController(ControllerBase):
            env: env instance
            low_level_policies: low level policies dictionary
        """
        super(OnlineMCTSController, self).__init__(env, low_level_policies,
                                                   start_node_alias)
        self.curr_node_alias = start_node_alias
        self.controller_args_defaults = {
            "predictor": None,
            "max_depth": 5,  # MCTS depth
            "nb_traversals": 30,  # MCTS traversals before decision
        }

    def set_current_node(self, node_alias):
@@ -48,12 +50,14 @@ class OnlineMCTSController(ControllerBase):
        env_before_mcts = orig_env.copy()
        self.change_low_level_references(env_before_mcts)
        print('Current Node: %s' % self.curr_node_alias)
        mcts = MCTSLearner(self.env, self.low_level_policies,
                           self.curr_node_alias)
        mcts.max_depth = self.max_depth
        mcts.set_controller_args(predictor=self.predictor)
        # Do nb_traversals number of traversals, reset env to this point every time
        # print('Doing MCTS with params: max_depth = %d, nb_traversals = %d' % (self.max_depth, self.nb_traversals))
        for num_epoch in range(
                self.nb_traversals):  # tqdm.tqdm(range(self.nb_traversals)):
            mcts.curr_node_num = 0
            env_begin_epoch = env_before_mcts.copy()
            self.change_low_level_references(env_begin_epoch)
@@ -63,6 +67,7 @@ class OnlineMCTSController(ControllerBase):
        # Find the nodes from the root node
        mcts.curr_node_num = 0
        print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
        node_after_transition = mcts.get_best_node(
            self.env.get_features_tuple(), use_ucb=False)
        print('MCTS suggested next option: %s' % node_after_transition)
        self.set_current_node(node_after_transition)
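The surrounding method is the online MCTS decision step: snapshot the environment, run nb_traversals search traversals from that snapshot, then query the tree greedily (use_ucb=False) for the next option. A schematic, self-contained sketch of that loop; CountingPlanner and ToyEnv are invented stand-ins, not the project's MCTSLearner or env API:

import random


class CountingPlanner:
    """Toy stand-in for MCTSLearner: counts how often each option is visited."""

    def __init__(self, options):
        self.visits = {option: 0 for option in options}

    def traverse(self, env_copy):
        # A real traversal would simulate options in env_copy; here we sample one.
        self.visits[random.choice(list(self.visits))] += 1

    def best_option(self, features, use_ucb=False):
        # Greedy choice over visit counts, mirroring get_best_node(use_ucb=False).
        return max(self.visits, key=self.visits.get)


class ToyEnv:
    def copy(self):
        return ToyEnv()

    def get_features_tuple(self):
        return (0.0,)


def choose_next_option(env, planner, nb_traversals=30):
    snapshot = env.copy()                  # plan from a frozen copy of the env
    for _ in range(nb_traversals):
        planner.traverse(snapshot.copy())  # each traversal restarts from the snapshot
    return planner.best_option(env.get_features_tuple(), use_ucb=False)


planner = CountingPlanner(['keep_lane', 'change_lane'])
print(choose_next_option(ToyEnv(), planner))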
class PolicyBase:
    """Abstract policy base from which every policy backend is defined and
    inherited."""
@@ -11,7 +11,8 @@ class RLController(ControllerBase):
            env: env instance
            low_level_policies: low level policies dictionary
        """
        super(RLController, self).__init__(env, low_level_policies,
                                           start_node_alias)
        self.low_level_policy_aliases = list(self.low_level_policies.keys())
        self.trained_policy = None
        self.node_terminal_state_reached = False
@@ -32,6 +33,8 @@ class RLController(ControllerBase):
        if self.trained_policy is None:
            raise Exception(self.__class__.__name__ + \
                "trained_policy is not set. Use set_trained_policy().")
        node_index_after_transition = self.trained_policy(
            self.env.get_features_tuple())
        self.set_current_node(
            self.low_level_policy_aliases[node_index_after_transition])
        self.node_terminal_state_reached = False
@@ -3,22 +3,22 @@ from model_checker import Parser


class GymCompliantEnvBase:
    def step(self, action):
        """Gym compliant step function which will be implemented in the
        subclass."""
        raise NotImplementedError(self.__class__.__name__ +
                                  "step is not implemented.")

    def reset(self):
        """Gym compliant reset function which will be implemented in the
        subclass."""
        raise NotImplementedError(self.__class__.__name__ +
                                  "reset is not implemented.")

    def render(self):
        """Gym compliant render function which will be implemented in the
        subclass."""
        raise NotImplementedError(self.__class__.__name__ +
                                  "render is not implemented.")

class EpisodicEnvBase(GymCompliantEnvBase):
@@ -46,7 +46,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
    def _init_LTL_preconditions(self):
        """Initialize the LTL preconditions (self._LTL_preconditions)..

        in the subclass.
        """
        return
@@ -69,21 +70,23 @@ class EpisodicEnvBase(GymCompliantEnvBase):
        self._r_terminal = None

    def _terminal_reward_superposition(self, r_obs):
        """Calculate the next value when observing "obs," from the prior using
        min, max, or summation depending on the superposition_type."""
        if r_obs is None:
            return

        if self.terminal_reward_type == 'min':
            self._r_terminal = r_obs if self._r_terminal is None else min(
                self._r_terminal, r_obs)
        elif self.terminal_reward_type == 'max':
            self._r_terminal = r_obs if self._r_terminal is None else max(
                self._r_terminal, r_obs)
        elif self.terminal_reward_type == 'sum':
            self._r_terminal = r_obs if self._r_terminal is None else self._r_terminal + r_obs
        else:
            raise AssertionError(
                "The terminal_reward_type has to be 'min', 'max', or 'sum'")

    def step(self, u):
        # the penalty is a negative reward.
@@ -118,7 +121,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
            if LTL_precondition.enabled:
                LTL_precondition.check_incremental(self.__mc_AP)
                if LTL_precondition.result == Parser.FALSE:
                    self._terminal_reward_superposition(
                        EpisodicEnvBase._reward(LTL_precondition.penalty))
                    violate = True
                    info['ltl_violation'] = LTL_precondition.str
                    # print("\nViolation of \"" + LTL_precondition.str + "\"")
@@ -126,7 +130,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
        return violate, info

    def current_model_checking_result(self):
        """Returns whether or not any of the conditions is currently
        violated."""
        for LTL_precondition in self._LTL_preconditions:
            if LTL_precondition.result == Parser.FALSE:
@@ -148,9 +153,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):