Commit f2171d2c authored by Aravind Balakrishnan

Merge branch 'formatting' into 'master'

Formatting

See merge request !3
parents e1fdb162 7abc600f
*.pyc
*.py~
__pycache__
@@ -35,7 +35,9 @@ class PPO2Agent(LearnerBase):
         self.log_path = log_path
-        self.env = DummyVecEnv([lambda: env])  #PPO2 requried a vectorized environment for parallel training
+        self.env = DummyVecEnv([
+            lambda: env
+        ])  #PPO2 requried a vectorized environment for parallel training
         self.agent_model = self.create_agent(policy, tensorboard)
 
     def get_default_policy(self):
@@ -46,13 +48,13 @@ class PPO2Agent(LearnerBase):
         return MlpPolicy
 
     def create_agent(self, policy, tensorboard):
-        """Creates a PPO agent
+        """Creates a PPO agent.
 
-        Returns:
-            stable_baselines PPO2 object
+        Returns: stable_baselines PPO2 object
         """
         if tensorboard:
-            return PPO2(policy, self.env, verbose=1, tensorboard_log=self.log_path)
+            return PPO2(
+                policy, self.env, verbose=1, tensorboard_log=self.log_path)
         else:
             return PPO2(policy, self.env, verbose=1)
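
The hunk above wraps the environment in a DummyVecEnv because PPO2 in stable-baselines only accepts vectorized environments. For reference, a minimal standalone sketch of the same construction (not part of this merge request; assumes stable-baselines 2.x and a Gym-compatible env):

# Minimal sketch, assuming stable-baselines 2.x and a Gym-compatible `env`.
from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv

def make_ppo2(env, log_path=None):
    vec_env = DummyVecEnv([lambda: env])  # PPO2 requires a vectorized env
    if log_path is not None:
        return PPO2(MlpPolicy, vec_env, verbose=1, tensorboard_log=log_path)
    return PPO2(MlpPolicy, vec_env, verbose=1)
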
@@ -100,7 +102,8 @@ class PPO2Agent(LearnerBase):
             episode_rewards[-1] += rewards[0]
             if dones[0] or current_step > nb_max_episode_steps:
                 obs = self.env.reset()
-                print ("Episode ", current_episode, "reward: ", episode_rewards[-1])
+                print("Episode ", current_episode, "reward: ",
+                      episode_rewards[-1])
                 episode_rewards.append(0.0)
                 current_episode += 1
                 current_step = 0
......
@@ -18,16 +18,17 @@ class ControllerBase(PolicyBase):
             setattr(self, prop, kwargs.get(prop, default))
 
     def can_transition(self):
-        """Returns boolean signifying whether we can transition. To be
-        implemented in subclass.
+        """Returns boolean signifying whether we can transition.
+
+        To be implemented in subclass.
         """
         raise NotImplemented(self.__class__.__name__ + \
             "can_transition is not implemented.")
 
     def do_transition(self, observation):
-        """Do a transition, assuming we can transition. To be
-        implemented in subclass.
+        """Do a transition, assuming we can transition. To be implemented in
+        subclass.
 
         Args:
             observation: final observation from episodic step
@@ -37,7 +38,7 @@ class ControllerBase(PolicyBase):
             "do_transition is not implemented.")
 
     def set_current_node(self, node_alias):
-        """Sets the current node which is being executed
+        """Sets the current node which is being executed.
 
         Args:
             node: node alias of the node to be set
@@ -50,11 +51,13 @@ class ControllerBase(PolicyBase):
     Returns state at end of node execution, total reward, epsiode_termination_flag, info
     '''
 
     def step_current_node(self, visualize_low_level_steps=False):
         total_reward = 0
         self.node_terminal_state_reached = False
         while not self.node_terminal_state_reached:
-            observation, reward, terminal, info = self.low_level_step_current_node()
+            observation, reward, terminal, info = self.low_level_step_current_node(
+            )
             if visualize_low_level_steps:
                 self.env.render()
             total_reward += reward
@@ -70,9 +73,12 @@ class ControllerBase(PolicyBase):
     Returns state after one step, step reward, episode_termination_flag, info
     '''
 
     def low_level_step_current_node(self):
-        u_ego = self.current_node.low_level_policy(self.current_node.get_reduced_features_tuple())
+        u_ego = self.current_node.low_level_policy(
+            self.current_node.get_reduced_features_tuple())
         feature, R, terminal, info = self.current_node.step(u_ego)
         self.node_terminal_state_reached = terminal
-        return self.env.get_features_tuple(), R, self.env.termination_condition, info
+        return self.env.get_features_tuple(
+        ), R, self.env.termination_condition, info
@@ -47,7 +47,8 @@ class DDPGLearner(LearnerBase):
             "oup_mu": 0,  # OrnsteinUhlenbeckProcess mu
             "oup_sigma": 1,  # OrnsteinUhlenbeckProcess sigma
             "oup_sigma_min": 0.5,  # OrnsteinUhlenbeckProcess sigma min
-            "oup_annealing_steps": 500000,  # OrnsteinUhlenbeckProcess n-step annealing
+            "oup_annealing_steps":
+            500000,  # OrnsteinUhlenbeckProcess n-step annealing
             "nb_steps_warmup_critic": 100,  # steps for critic to warmup
             "nb_steps_warmup_actor": 100,  # steps for actor to warmup
             "target_model_update": 1e-3  # target model update frequency
@@ -160,7 +161,12 @@ class DDPGLearner(LearnerBase):
             target_model_update=1e-3)
 
         # TODO: give params like lr_actor and lr_critic to set different lr of Actor and Critic.
-        agent.compile([Adam(lr=self.lr*1e-2, clipnorm=1.), Adam(lr=self.lr, clipnorm=1.)], metrics=['mae'])
+        agent.compile(
+            [
+                Adam(lr=self.lr * 1e-2, clipnorm=1.),
+                Adam(lr=self.lr, clipnorm=1.)
+            ],
+            metrics=['mae'])
 
         return agent
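
keras-rl's DDPGAgent.compile accepts either one optimizer or a list of two, which is how the hunk above gives the actor and critic different learning rates; the library's convention is that the first entry is used for the actor and the second for the critic, though that is worth re-checking against the installed version. A rough sketch of just this step (not part of this merge request; `agent` is assumed to be an already-built DDPGAgent):

# Rough sketch; `agent` is an already-built keras-rl DDPGAgent, `lr` a base rate.
from keras.optimizers import Adam

def compile_with_split_lr(agent, lr=1e-3):
    agent.compile(
        [Adam(lr=lr * 1e-2, clipnorm=1.),  # first optimizer: actor (per keras-rl docs)
         Adam(lr=lr, clipnorm=1.)],        # second optimizer: critic
        metrics=['mae'])
    return agent
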
@@ -177,7 +183,11 @@ class DDPGLearner(LearnerBase):
         callbacks = []
         if model_checkpoints:
-            callbacks += [ModelIntervalCheckpoint('./checkpoints/checkpoint_weights.h5f', interval=checkpoint_interval)]
+            callbacks += [
+                ModelIntervalCheckpoint(
+                    './checkpoints/checkpoint_weights.h5f',
+                    interval=checkpoint_interval)
+            ]
         if tensorboard:
             callbacks += [TensorBoard(log_dir='./logs')]
@@ -291,11 +301,15 @@ class DQNLearner(LearnerBase):
         Returns:
             KerasRL DQN object
         """
-        agent = DQNAgentOverOptions(model=model, low_level_policies=self.low_level_policies,
-                                    nb_actions=self.nb_actions, memory=memory,
-                                    nb_steps_warmup=self.nb_steps_warmup, target_model_update=self.target_model_update,
-                                    policy=policy, enable_dueling_network=True)
+        agent = DQNAgentOverOptions(
+            model=model,
+            low_level_policies=self.low_level_policies,
+            nb_actions=self.nb_actions,
+            memory=memory,
+            nb_steps_warmup=self.nb_steps_warmup,
+            target_model_update=self.target_model_update,
+            policy=policy,
+            enable_dueling_network=True)
 
         agent.compile(Adam(lr=self.lr), metrics=['mae'])
@@ -312,7 +326,11 @@ class DQNLearner(LearnerBase):
         callbacks = []
         if model_checkpoints:
-            callbacks += [ModelIntervalCheckpoint('./checkpoints/checkpoint_weights.h5f', interval=checkpoint_interval)]
+            callbacks += [
+                ModelIntervalCheckpoint(
+                    './checkpoints/checkpoint_weights.h5f',
+                    interval=checkpoint_interval)
+            ]
         if tensorboard:
             callbacks += [TensorBoard(log_dir='./logs')]
@@ -333,7 +351,7 @@ class DQNLearner(LearnerBase):
                    nb_episodes=5,
                    visualize=True,
                    nb_max_episode_steps=400,
-                   success_reward_threshold = 100):
+                   success_reward_threshold=100):
 
         print("Testing for {} episodes".format(nb_episodes))
         success_count = 0
@@ -359,13 +377,14 @@ class DQNLearner(LearnerBase):
                 env.reset()
             if episode_reward >= success_reward_threshold:
                 success_count += 1
-            print("Episode {}: steps:{}, reward:{}".format(n+1, step, episode_reward))
+            print("Episode {}: steps:{}, reward:{}".format(
+                n + 1, step, episode_reward))
 
-        print ("\nPolicy succeeded {} times!".format(success_count))
-        print ("Failures due to:")
-        print (termination_reason_counter)
+        print("\nPolicy succeeded {} times!".format(success_count))
+        print("Failures due to:")
+        print(termination_reason_counter)
 
-        return [success_count,termination_reason_counter]
+        return [success_count, termination_reason_counter]
 
     def load_model(self, file_name="test_weights.h5f"):
         self.agent_model.load_weights(file_name)
@@ -377,31 +396,43 @@ class DQNLearner(LearnerBase):
         return self.agent_model.get_modified_q_values(observation)[action]
 
     def get_q_value_using_option_alias(self, observation, option_alias):
-        action_num = self.agent_model.low_level_policy_aliases.index(option_alias)
+        action_num = self.agent_model.low_level_policy_aliases.index(
+            option_alias)
         return self.agent_model.get_modified_q_values(observation)[action_num]
 
     def get_softq_value_using_option_alias(self, observation, option_alias):
-        action_num = self.agent_model.low_level_policy_aliases.index(option_alias)
+        action_num = self.agent_model.low_level_policy_aliases.index(
+            option_alias)
         q_values = self.agent_model.get_modified_q_values(observation)
         max_q_value = np.abs(np.max(q_values))
-        q_values = [np.exp(q_value/max_q_value) for q_value in q_values]
-        relevant = q_values[action_num]/np.sum(q_values)
+        q_values = [np.exp(q_value / max_q_value) for q_value in q_values]
+        relevant = q_values[action_num] / np.sum(q_values)
         return relevant
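
get_softq_value_using_option_alias above converts raw Q-values into a softmax-style weight for one option, scaling by the absolute value of the maximum Q before exponentiating. The same arithmetic as a standalone helper, for reference (a sketch, not part of this merge request):

# Sketch of the soft-Q weighting above, on a plain array of Q-values.
import numpy as np

def soft_q_weight(q_values, action_num):
    q_values = np.asarray(q_values, dtype=float)
    max_q_value = np.abs(np.max(q_values))    # same scale factor as above
    exp_q = np.exp(q_values / max_q_value)
    return exp_q[action_num] / np.sum(exp_q)  # normalized weight in (0, 1)

# Example: soft_q_weight([1.0, 2.0, -0.5], 1) gives the weight of the second option.
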
 
 class DQNAgentOverOptions(DQNAgent):
-    def __init__(self, model, low_level_policies, policy=None, test_policy=None, enable_double_dqn=True, enable_dueling_network=False,
-                 dueling_type='avg', *args, **kwargs):
-        super(DQNAgentOverOptions, self).__init__(model, policy, test_policy, enable_double_dqn, enable_dueling_network,
-                                                  dueling_type, *args, **kwargs)
+    def __init__(self,
+                 model,
+                 low_level_policies,
+                 policy=None,
+                 test_policy=None,
+                 enable_double_dqn=True,
+                 enable_dueling_network=False,
+                 dueling_type='avg',
+                 *args,
+                 **kwargs):
+        super(DQNAgentOverOptions, self).__init__(
+            model, policy, test_policy, enable_double_dqn,
+            enable_dueling_network, dueling_type, *args, **kwargs)
 
         self.low_level_policies = low_level_policies
         if low_level_policies is not None:
-            self.low_level_policy_aliases = list(self.low_level_policies.keys())
+            self.low_level_policy_aliases = list(
+                self.low_level_policies.keys())
 
     def __get_invalid_node_indices(self):
-        """Returns a list of option indices that are invalid according to initiation conditions.
-        """
+        """Returns a list of option indices that are invalid according to
+        initiation conditions."""
         invalid_node_indices = list()
         for index, option_alias in enumerate(self.low_level_policy_aliases):
             self.low_level_policies[option_alias].reset_maneuver()
@@ -435,5 +466,3 @@ class DQNAgentOverOptions(DQNAgent):
             q_values[node_index] = -np.inf
 
         return q_values
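
The hunk above shows the tail of the Q-value override: options whose initiation condition is not met have their Q-values forced to -np.inf, so a greedy or epsilon-greedy policy can never select them. A minimal sketch of that masking idea in isolation (illustrative names, not the repository's API):

# Sketch of masking unavailable options before an argmax, mirroring the
# -np.inf trick above.
import numpy as np

def mask_invalid_options(q_values, invalid_indices):
    masked = np.array(q_values, dtype=float)
    for idx in invalid_indices:
        masked[idx] = -np.inf  # can never win an argmax
    return masked

# np.argmax(mask_invalid_options([0.2, 0.9, 0.4], [1])) -> 2 (option 1 excluded)
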
-from.policy_base import PolicyBase
+from .policy_base import PolicyBase
 import numpy as np
......
 from .controller_base import ControllerBase
 
 
 class ManualPolicy(ControllerBase):
     """Manual policy execution using nodes and edges."""
 
-    def __init__(self, env, low_level_policies, transition_adj, start_node_alias):
+    def __init__(self, env, low_level_policies, transition_adj,
+                 start_node_alias):
         """Constructor for manual policy execution.
 
         Args:
@@ -13,12 +15,13 @@ class ManualPolicy(ControllerBase):
             start_node: starting node
         """
 
-        super(ManualPolicy, self).__init__(env, low_level_policies, start_node_alias)
+        super(ManualPolicy, self).__init__(env, low_level_policies,
+                                           start_node_alias)
         self.adj = transition_adj
 
     def _transition(self):
-        """Check if the current node's termination condition is met and if
-        it is possible to transition to another node, i.e. its initiation
+        """Check if the current node's termination condition is met and if it
+        is possible to transition to another node, i.e. its initiation
         condition is met. This is an internal function.
 
         Returns the new node if a transition can happen, None otherwise
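
The _transition docstring above describes the rule: once the current node's termination condition holds, return an adjacent node whose initiation condition holds, otherwise return None. A hedged sketch of that logic; the method names termination_condition and initiation_condition are illustrative placeholders, not the repository's actual API:

# Hypothetical sketch of the transition rule described above; the option
# interface (termination_condition / initiation_condition) is assumed.
def pick_next_node(current_alias, adj, low_level_policies, env):
    current = low_level_policies[current_alias]
    if not current.termination_condition(env):
        return None  # current option is still executing
    for next_alias in adj.get(current_alias, []):
        if low_level_policies[next_alias].initiation_condition(env):
            return next_alias  # first adjacent option that may start
    return None  # no valid transition
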
......
@@ -2,15 +2,15 @@ from .controller_base import ControllerBase
 import numpy as np
 import pickle
 
 
 class MCTSLearner(ControllerBase):
-    """Monte Carlo Tree Search implementation using the UCB1 and
-    progressive widening approach as explained in Paxton et al (2017).
-    """
+    """Monte Carlo Tree Search implementation using the UCB1 and progressive
+    widening approach as explained in Paxton et al (2017)."""
 
     _ucb_vals = set()
 
-    def __init__(self, env, low_level_policies,
-                 start_node_alias, max_depth=10):
+    def __init__(self, env, low_level_policies, start_node_alias,
+                 max_depth=10):
         """Constructor for MCTSLearner.
 
         Args:
@@ -22,10 +22,12 @@ class MCTSLearner(ControllerBase):
             max_depth: max depth of the MCTS tree; default 10 levels
         """
 
-        super(MCTSLearner, self).__init__(env, low_level_policies, start_node_alias)
+        super(MCTSLearner, self).__init__(env, low_level_policies,
+                                          start_node_alias)
         self.controller_args_defaults = {
-            "predictor": None  #P(s, o) learner class; forward pass should return the entire value from state s and option o
+            "predictor":
+            None  #P(s, o) learner class; forward pass should return the entire value from state s and option o
         }
         self.max_depth = max_depth
 
         #: store current node alias
@@ -51,8 +53,14 @@ class MCTSLearner(ControllerBase):
         self.adj[root_node_num] = set()  # no children
 
     def save_model(self, file_name="mcts.pickle"):
-        to_backup = {'N': self.N, 'M': self.M, 'TR': self.TR, 'nodes': self.nodes,
-                     'adj': self.adj, 'new_node_num': self.new_node_num}
+        to_backup = {
+            'N': self.N,
+            'M': self.M,
+            'TR': self.TR,
+            'nodes': self.nodes,
+            'adj': self.adj,
+            'new_node_num': self.new_node_num
+        }
         with open(file_name, 'wb') as handle:
             pickle.dump(to_backup, handle, protocol=pickle.HIGHEST_PROTOCOL)
@@ -67,8 +75,8 @@ class MCTSLearner(ControllerBase):
         self.new_node_num = to_restore['new_node_num']
 
     def _create_node(self, low_level_policy):
-        """Create the node associated with curr_node_num, using the
-        given low level policy.
+        """Create the node associated with curr_node_num, using the given low
+        level policy.
 
         Args:
             low_level_policy: the option's alias
@@ -83,11 +91,10 @@ class MCTSLearner(ControllerBase):
         return created_node_num, {"policy": low_level_policy}
 
     def _to_discrete(self, observation):
-        """Converts observation to a discrete observation tuple. Also
-        append (a) whether we are following a vehicle, and (b) whether
-        there is a vehicle in the opposite lane in the approximately
-        the same x position. These values will be useful for Follow
-        and ChangeLane maneuvers.
+        """Converts observation to a discrete observation tuple. Also append
+        (a) whether we are following a vehicle, and (b) whether there is a
+        vehicle in the opposite lane in the approximately the same x position.
+        These values will be useful for Follow and ChangeLane maneuvers.
 
         Args:
             observation: observation tuple from the environment
@@ -96,9 +103,9 @@ class MCTSLearner(ControllerBase):
         """
         dis_observation = ''
         for item in observation[12:20]:
-            if type(item)==bool:
+            if type(item) == bool:
                 dis_observation += '1' if item is True else '0'
-            if type(item)==int and item in [0, 1]:
+            if type(item) == int and item in [0, 1]:
                 dis_observation += str(item)
 
         env = self.current_node.env
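
_to_discrete above builds a string key from a slice of the observation, mapping booleans and 0/1 integers to single characters. The same encoding as a standalone sketch (not part of this merge request):

# Sketch of the bit-string discretization above; elements 12..19 of
# `observation` are assumed to be booleans or 0/1 integers.
def to_bit_string(observation):
    dis_observation = ''
    for item in observation[12:20]:
        if isinstance(item, bool):
            dis_observation += '1' if item else '0'
        elif isinstance(item, int) and item in (0, 1):
            dis_observation += str(item)
    return dis_observation

# to_bit_string([0.0] * 12 + [True, False, 1, 0, True, 0, 1, False]) -> '10101010'
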
@@ -128,9 +135,9 @@ class MCTSLearner(ControllerBase):
     def _get_visitation_count(self, observation, option=None):
         """Finds the visitation count of the discrete form of the observation.
 
-        If discrete observation not found, then inserted into self.N with
-        value 0. Auto converts the observation into discrete form. If option
-        is not None, then this uses self.M instead of self.N
+        If discrete observation not found, then inserted into self.N with value
+        0. Auto converts the observation into discrete form. If option is not
+        None, then this uses self.M instead of self.N.
 
         Args:
             observation: observation tuple from the environment
@@ -163,12 +170,14 @@ class MCTSLearner(ControllerBase):
         dis_observation = self._to_discrete(observation)
         if (dis_observation, option) not in self.TR:
             self.TR[(dis_observation, option)] = 0
-        return self.TR[(dis_observation, option)] / (1+self._get_visitation_count(observation, option))
+        return self.TR[(dis_observation, option)] / (
+            1 + self._get_visitation_count(observation, option))
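
The return statement above estimates the value of an (observation, option) pair as accumulated tree reward divided by one plus its visitation count; the +1 avoids division by zero for unvisited pairs. A small sketch of the same estimate (TR and N stand in for self.TR and the per-option counts):

# Sketch of the average-return estimate above. TR and N are dicts keyed by
# (discrete_observation, option), standing in for self.TR and self.M.
def estimated_q(TR, N, dis_observation, option):
    total_reward = TR.get((dis_observation, option), 0.0)
    visits = N.get((dis_observation, option), 0)
    return total_reward / (1 + visits)  # +1 guards against division by zero
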
 
     def _get_possible_options(self):
         """Returns a set of options that can be taken from the current node.
-        Goes through adjacency set of current node and finds which next nodes'
-        initiation condition is met.
+
+        Goes through adjacency set of current node and finds which next
+        nodes' initiation condition is met.
         """
 
         all_options = set(self.low_level_policies.keys())
@@ -201,9 +210,9 @@ class MCTSLearner(ControllerBase):
         return visited_aliases
 
     def _ucb_adjusted_q(self, observation, C=1):
-        """Computes Q_star(observation, option_i) plus the UCB term, which
-        is C*[predictor(observation, option_i)]/[1+N(observation, option_i)],
-        for all option_i in the adjacency set of the current node.
+        """Computes Q_star(observation, option_i) plus the UCB term, which is
+        C*[predictor(observation, option_i)]/[1+N(observation, option_i)], for
+        all option_i in the adjacency set of the current node.
 
         Args:
             observation: observation tuple from the environment
@@ -218,7 +227,8 @@ class MCTSLearner(ControllerBase):
         next_option_nums = self.adj[self.curr_node_num]
         for next_option_num in next_option_nums:
             next_option = self.nodes[next_option_num]["policy"]
-            Q1[next_option] = (self._get_q_star(observation, next_option)+200)/400
+            Q1[next_option] = (
+                self._get_q_star(observation, next_option) + 200) / 400
             Q[(dis_observation, next_option)] = \
                 Q1[next_option]
             Q2[next_option] = C * \
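
_ucb_adjusted_q above adds a normalized Q estimate (shifted by 200 and divided by 400, apparently assuming returns in roughly [-200, 200]) to an exploration bonus C * predictor / (1 + N), in the spirit of UCB with a learned prior as in Paxton et al. (2017). A standalone sketch of that combination for a single (observation, option) pair; the normalization constants are taken from this diff rather than verified elsewhere:

# Sketch of the UCB-adjusted score above for one (observation, option) pair.
def ucb_adjusted_q(q_star, prior, visits, C=1.0):
    q_normalized = (q_star + 200.0) / 400.0   # squash assumed [-200, 200] range into [0, 1]
    exploration = C * prior / (1.0 + visits)  # predictor-guided exploration bonus
    return q_normalized + exploration
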
@@ -243,11 +253,11 @@ class MCTSLearner(ControllerBase):
         relevant_rewards = [value for key, value in self.TR.items() \
             if key[0] == dis_observation]
         sum_rewards = np.sum(relevant_rewards)
-        return sum_rewards / (1+self._get_visitation_count(observation))
+        return sum_rewards / (1 + self._get_visitation_count(observation))
 
     def _select(self, observation, depth=0, visualize=False):
-        """MCTS selection function. For representation, we only use
-        the discrete part of the observation.
+        """MCTS selection function. For representation, we only use the
+        discrete part of the observation.
 
         Args:
             observation: observation tuple from the environment
@@ -266,7 +276,8 @@ class MCTSLearner(ControllerBase):
         if is_terminal or max_depth_reached:
             # print('MCTS went %d nodes deep' % depth)
-            return self._value(observation), 0  # TODO: replace with final goal reward
+            return self._value(
+                observation), 0  # TODO: replace with final goal reward