Commit 72e44d55 authored by Ashish Gaurav

format using yapf

parent e1fdb162
......@@ -35,7 +35,9 @@ class PPO2Agent(LearnerBase):
self.log_path = log_path
self.env = DummyVecEnv([lambda: env]) # PPO2 requires a vectorized environment for parallel training
self.env = DummyVecEnv([
lambda: env
]) # PPO2 requires a vectorized environment for parallel training
self.agent_model = self.create_agent(policy, tensorboard)
def get_default_policy(self):
......@@ -52,7 +54,8 @@ class PPO2Agent(LearnerBase):
stable_baselines PPO2 object
"""
if tensorboard:
return PPO2(policy, self.env, verbose=1, tensorboard_log=self.log_path)
return PPO2(
policy, self.env, verbose=1, tensorboard_log=self.log_path)
else:
return PPO2(policy, self.env, verbose=1)
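For context, a minimal standalone sketch of the stable-baselines 2.x pattern used above: wrapping a single Gym environment in DummyVecEnv (PPO2 requires a vectorized environment) and optionally pointing PPO2 at a TensorBoard log directory. The environment name, log path, and step count are illustrative only.

# Sketch only; assumes stable-baselines 2.x and gym are installed.
import gym
from stable_baselines import PPO2
from stable_baselines.common.vec_env import DummyVecEnv

env = gym.make('CartPole-v1')          # any Gym-compliant environment
vec_env = DummyVecEnv([lambda: env])   # PPO2 requires a vectorized environment
model = PPO2('MlpPolicy', vec_env, verbose=1, tensorboard_log='./logs')
model.learn(total_timesteps=10000)     # short training run for illustration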
......@@ -100,7 +103,8 @@ class PPO2Agent(LearnerBase):
episode_rewards[-1] += rewards[0]
if dones[0] or current_step > nb_max_episode_steps:
obs = self.env.reset()
print ("Episode ", current_episode, "reward: ", episode_rewards[-1])
print("Episode ", current_episode, "reward: ",
episode_rewards[-1])
episode_rewards.append(0.0)
current_episode += 1
current_step = 0
......
......@@ -50,11 +50,13 @@ class ControllerBase(PolicyBase):
Returns state at end of node execution, total reward, episode_termination_flag, info
'''
def step_current_node(self, visualize_low_level_steps=False):
total_reward = 0
self.node_terminal_state_reached = False
while not self.node_terminal_state_reached:
observation, reward, terminal, info = self.low_level_step_current_node()
observation, reward, terminal, info = self.low_level_step_current_node(
)
if visualize_low_level_steps:
self.env.render()
total_reward += reward
......@@ -70,9 +72,12 @@ class ControllerBase(PolicyBase):
Returns state after one step, step reward, episode_termination_flag, info
'''
def low_level_step_current_node(self):
u_ego = self.current_node.low_level_policy(self.current_node.get_reduced_features_tuple())
u_ego = self.current_node.low_level_policy(
self.current_node.get_reduced_features_tuple())
feature, R, terminal, info = self.current_node.step(u_ego)
self.node_terminal_state_reached = terminal
return self.env.get_features_tuple(), R, self.env.termination_condition, info
return self.env.get_features_tuple(
), R, self.env.termination_condition, info
......@@ -47,7 +47,8 @@ class DDPGLearner(LearnerBase):
"oup_mu": 0, # OrnsteinUhlenbeckProcess mu
"oup_sigma": 1, # OrnsteinUhlenbeckProcess sigma
"oup_sigma_min": 0.5, # OrnsteinUhlenbeckProcess sigma min
"oup_annealing_steps": 500000, # OrnsteinUhlenbeckProcess n-step annealing
"oup_annealing_steps":
500000, # OrnsteinUhlenbeckProcess n-step annealing
"nb_steps_warmup_critic": 100, # steps for critic to warmup
"nb_steps_warmup_actor": 100, # steps for actor to warmup
"target_model_update": 1e-3 # target model update frequency
......@@ -160,24 +161,33 @@ class DDPGLearner(LearnerBase):
target_model_update=1e-3)
# TODO: expose params like lr_actor and lr_critic to set different learning rates for the actor and critic.
agent.compile([Adam(lr=self.lr*1e-2, clipnorm=1.), Adam(lr=self.lr, clipnorm=1.)], metrics=['mae'])
agent.compile(
[
Adam(lr=self.lr * 1e-2, clipnorm=1.),
Adam(lr=self.lr, clipnorm=1.)
],
metrics=['mae'])
return agent
def train(self,
env,
nb_steps=1000000,
visualize=False,
verbose=1,
log_interval=10000,
nb_max_episode_steps=200,
model_checkpoints=False,
checkpoint_interval=100000,
tensorboard=False):
env,
nb_steps=1000000,
visualize=False,
verbose=1,
log_interval=10000,
nb_max_episode_steps=200,
model_checkpoints=False,
checkpoint_interval=100000,
tensorboard=False):
callbacks = []
if model_checkpoints:
callbacks += [ModelIntervalCheckpoint('./checkpoints/checkpoint_weights.h5f', interval=checkpoint_interval)]
callbacks += [
ModelIntervalCheckpoint(
'./checkpoints/checkpoint_weights.h5f',
interval=checkpoint_interval)
]
if tensorboard:
callbacks += [TensorBoard(log_dir='./logs')]
......@@ -291,28 +301,36 @@ class DQNLearner(LearnerBase):
Returns:
KerasRL DQN object
"""
agent = DQNAgentOverOptions(model=model, low_level_policies=self.low_level_policies,
nb_actions=self.nb_actions, memory=memory,
nb_steps_warmup=self.nb_steps_warmup, target_model_update=self.target_model_update,
policy=policy, enable_dueling_network=True)
agent = DQNAgentOverOptions(
model=model,
low_level_policies=self.low_level_policies,
nb_actions=self.nb_actions,
memory=memory,
nb_steps_warmup=self.nb_steps_warmup,
target_model_update=self.target_model_update,
policy=policy,
enable_dueling_network=True)
agent.compile(Adam(lr=self.lr), metrics=['mae'])
return agent
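For completeness, a hedged sketch of how the memory and policy arguments passed into create_agent are commonly constructed in keras-rl; the buffer limit and epsilon values are illustrative, not this repository's settings.

# Sketch only; assumes keras-rl (rl) is installed.
from rl.memory import SequentialMemory
from rl.policy import EpsGreedyQPolicy

def build_dqn_inputs():
    memory = SequentialMemory(limit=50000, window_length=1)  # replay buffer
    policy = EpsGreedyQPolicy(eps=0.1)  # epsilon-greedy exploration
    return memory, policy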
def train(self,
env,
nb_steps=1000000,
visualize=False,
nb_max_episode_steps=200,
tensorboard=False,
model_checkpoints=False,
checkpoint_interval=10000):
env,
nb_steps=1000000,
visualize=False,
nb_max_episode_steps=200,
tensorboard=False,
model_checkpoints=False,
checkpoint_interval=10000):
callbacks = []
if model_checkpoints:
callbacks += [ModelIntervalCheckpoint('./checkpoints/checkpoint_weights.h5f', interval=checkpoint_interval)]
callbacks += [
ModelIntervalCheckpoint(
'./checkpoints/checkpoint_weights.h5f',
interval=checkpoint_interval)
]
if tensorboard:
callbacks += [TensorBoard(log_dir='./logs')]
......@@ -333,7 +351,7 @@ class DQNLearner(LearnerBase):
nb_episodes=5,
visualize=True,
nb_max_episode_steps=400,
success_reward_threshold = 100):
success_reward_threshold=100):
print("Testing for {} episodes".format(nb_episodes))
success_count = 0
......@@ -359,13 +377,14 @@ class DQNLearner(LearnerBase):
env.reset()
if episode_reward >= success_reward_threshold:
success_count += 1
print("Episode {}: steps:{}, reward:{}".format(n+1, step, episode_reward))
print("Episode {}: steps:{}, reward:{}".format(
n + 1, step, episode_reward))
print ("\nPolicy succeeded {} times!".format(success_count))
print ("Failures due to:")
print (termination_reason_counter)
print("\nPolicy succeeded {} times!".format(success_count))
print("Failures due to:")
print(termination_reason_counter)
return [success_count,termination_reason_counter]
return [success_count, termination_reason_counter]
def load_model(self, file_name="test_weights.h5f"):
self.agent_model.load_weights(file_name)
......@@ -377,27 +396,39 @@ class DQNLearner(LearnerBase):
return self.agent_model.get_modified_q_values(observation)[action]
def get_q_value_using_option_alias(self, observation, option_alias):
action_num = self.agent_model.low_level_policy_aliases.index(option_alias)
action_num = self.agent_model.low_level_policy_aliases.index(
option_alias)
return self.agent_model.get_modified_q_values(observation)[action_num]
def get_softq_value_using_option_alias(self, observation, option_alias):
action_num = self.agent_model.low_level_policy_aliases.index(option_alias)
action_num = self.agent_model.low_level_policy_aliases.index(
option_alias)
q_values = self.agent_model.get_modified_q_values(observation)
max_q_value = np.abs(np.max(q_values))
q_values = [np.exp(q_value/max_q_value) for q_value in q_values]
relevant = q_values[action_num]/np.sum(q_values)
q_values = [np.exp(q_value / max_q_value) for q_value in q_values]
relevant = q_values[action_num] / np.sum(q_values)
return relevant
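The soft-Q computation above rescales the Q-values by the magnitude of the maximum Q-value, exponentiates, and normalizes so the option weights sum to one. A small self-contained sketch of the same normalization, with made-up numbers:

import numpy as np

def soft_q_weights(q_values):
    # Same normalization as get_softq_value_using_option_alias, on a plain list.
    max_q_value = np.abs(np.max(q_values))
    scaled = [np.exp(q / max_q_value) for q in q_values]
    return np.array(scaled) / np.sum(scaled)

# e.g. soft_q_weights([4.0, 2.0, -1.0]) is approximately [0.53, 0.32, 0.15]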
class DQNAgentOverOptions(DQNAgent):
def __init__(self, model, low_level_policies, policy=None, test_policy=None, enable_double_dqn=True, enable_dueling_network=False,
dueling_type='avg', *args, **kwargs):
super(DQNAgentOverOptions, self).__init__(model, policy, test_policy, enable_double_dqn, enable_dueling_network,
dueling_type, *args, **kwargs)
class DQNAgentOverOptions(DQNAgent):
def __init__(self,
model,
low_level_policies,
policy=None,
test_policy=None,
enable_double_dqn=True,
enable_dueling_network=False,
dueling_type='avg',
*args,
**kwargs):
super(DQNAgentOverOptions, self).__init__(
model, policy, test_policy, enable_double_dqn,
enable_dueling_network, dueling_type, *args, **kwargs)
self.low_level_policies = low_level_policies
if low_level_policies is not None:
self.low_level_policy_aliases = list(self.low_level_policies.keys())
self.low_level_policy_aliases = list(
self.low_level_policies.keys())
def __get_invalid_node_indices(self):
"""Returns a list of option indices that are invalid according to initiation conditions.
......@@ -435,5 +466,3 @@ class DQNAgentOverOptions(DQNAgent):
q_values[node_index] = -np.inf
return q_values
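The masking above relies on a simple property: once an option's Q-value is set to -np.inf, a greedy argmax can never select it. A hypothetical standalone illustration:

import numpy as np

def mask_invalid(q_values, invalid_indices):
    # Hypothetical helper mirroring the -inf masking of invalid options.
    q_values = np.array(q_values, dtype=float)
    q_values[list(invalid_indices)] = -np.inf
    return q_values

# e.g. np.argmax(mask_invalid([1.0, 5.0, 3.0], [1])) == 2, skipping option 1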
from.policy_base import PolicyBase
from .policy_base import PolicyBase
import numpy as np
......@@ -23,10 +23,10 @@ class LearnerBase(PolicyBase):
setattr(self, prop, kwargs.get(prop, default))
def train(self,
env,
nb_steps=50000,
visualize=False,
nb_max_episode_steps=200):
env,
nb_steps=50000,
visualize=False,
nb_max_episode_steps=200):
"""Train the learning agent on the environment.
Args:
......
from .controller_base import ControllerBase
class ManualPolicy(ControllerBase):
"""Manual policy execution using nodes and edges."""
def __init__(self, env, low_level_policies, transition_adj, start_node_alias):
def __init__(self, env, low_level_policies, transition_adj,
start_node_alias):
"""Constructor for manual policy execution.
Args:
......@@ -13,7 +15,8 @@ class ManualPolicy(ControllerBase):
start_node: starting node
"""
super(ManualPolicy, self).__init__(env, low_level_policies, start_node_alias)
super(ManualPolicy, self).__init__(env, low_level_policies,
start_node_alias)
self.adj = transition_adj
def _transition(self):
......@@ -50,4 +53,4 @@ class ManualPolicy(ControllerBase):
new_node = self._transition()
if new_node is not None:
self.current_node = new_node
\ No newline at end of file
self.current_node = new_node
......@@ -2,6 +2,7 @@ from .controller_base import ControllerBase
import numpy as np
import pickle
class MCTSLearner(ControllerBase):
"""Monte Carlo Tree Search implementation using the UCB1 and
progressive widening approach as explained in Paxton et al. (2017).
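For reference, a brief sketch of the two ideas named in the docstring. The exploration constant C and the widening parameters k and alpha are illustrative; the exact rule used in this class may differ.

import numpy as np

def ucb1_score(q_star, N_s, M_so, C=1.0):
    # UCB1: exploitation term plus an exploration bonus that shrinks
    # as the (state, option) pair is visited more often.
    return q_star + C * np.sqrt(np.log(N_s + 1) / (M_so + 1))

def should_widen(num_children, N_s, k=1.0, alpha=0.5):
    # Progressive widening: only expand a new child option once the
    # state's visit count has grown enough.
    return num_children <= k * (N_s ** alpha)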
......@@ -9,8 +10,8 @@ class MCTSLearner(ControllerBase):
_ucb_vals = set()
def __init__(self, env, low_level_policies,
start_node_alias, max_depth=10):
def __init__(self, env, low_level_policies, start_node_alias,
max_depth=10):
"""Constructor for MCTSLearner.
Args:
......@@ -22,10 +23,12 @@ class MCTSLearner(ControllerBase):
max_depth: max depth of the MCTS tree; default 10 levels
"""
super(MCTSLearner, self).__init__(env, low_level_policies, start_node_alias)
super(MCTSLearner, self).__init__(env, low_level_policies,
start_node_alias)
self.controller_args_defaults = {
"predictor": None #P(s, o) learner class; forward pass should return the entire value from state s and option o
"predictor":
None # P(s, o) learner class; its forward pass should return the value of taking option o from state s
}
self.max_depth = max_depth
#: store current node alias
......@@ -48,11 +51,17 @@ class MCTSLearner(ControllerBase):
# populate root node
root_node_num, root_node_info = self._create_node(self.curr_node_alias)
self.nodes[root_node_num] = root_node_info
self.adj[root_node_num] = set() # no children
self.adj[root_node_num] = set() # no children
def save_model(self, file_name="mcts.pickle"):
to_backup = {'N': self.N, 'M': self.M, 'TR': self.TR, 'nodes': self.nodes,
'adj': self.adj, 'new_node_num': self.new_node_num}
to_backup = {
'N': self.N,
'M': self.M,
'TR': self.TR,
'nodes': self.nodes,
'adj': self.adj,
'new_node_num': self.new_node_num
}
with open(file_name, 'wb') as handle:
pickle.dump(to_backup, handle, protocol=pickle.HIGHEST_PROTOCOL)
......@@ -96,9 +105,9 @@ class MCTSLearner(ControllerBase):
"""
dis_observation = ''
for item in observation[12:20]:
if type(item)==bool:
if type(item) == bool:
dis_observation += '1' if item is True else '0'
if type(item)==int and item in [0, 1]:
if type(item) == int and item in [0, 1]:
dis_observation += str(item)
env = self.current_node.env
......@@ -163,7 +172,8 @@ class MCTSLearner(ControllerBase):
dis_observation = self._to_discrete(observation)
if (dis_observation, option) not in self.TR:
self.TR[(dis_observation, option)] = 0
return self.TR[(dis_observation, option)] / (1+self._get_visitation_count(observation, option))
return self.TR[(dis_observation, option)] / (
1 + self._get_visitation_count(observation, option))
def _get_possible_options(self):
"""Returns a set of options that can be taken from the current node.
......@@ -172,7 +182,7 @@ class MCTSLearner(ControllerBase):
"""
all_options = set(self.low_level_policies.keys())
# Filter nodes whose initiation condition is true
filtered_options = set()
for option_alias in all_options:
......@@ -211,14 +221,15 @@ class MCTSLearner(ControllerBase):
Returns Q values for next nodes
"""
Q = {}
Q1, Q2 = {}, {} # debug
Q1, Q2 = {}, {} # debug
dis_observation = self._to_discrete(observation)
next_option_nums = self.adj[self.curr_node_num]
for next_option_num in next_option_nums:
next_option = self.nodes[next_option_num]["policy"]
Q1[next_option] = (self._get_q_star(observation, next_option)+200)/400
Q1[next_option] = (
self._get_q_star(observation, next_option) + 200) / 400
Q[(dis_observation, next_option)] = \
Q1[next_option]
Q2[next_option] = C * \
......@@ -243,7 +254,7 @@ class MCTSLearner(ControllerBase):
relevant_rewards = [value for key, value in self.TR.items() \
if key[0] == dis_observation]
sum_rewards = np.sum(relevant_rewards)
return sum_rewards / (1+self._get_visitation_count(observation))
return sum_rewards / (1 + self._get_visitation_count(observation))
def _select(self, observation, depth=0, visualize=False):
"""MCTS selection function. For representation, we only use
......@@ -266,7 +277,8 @@ class MCTSLearner(ControllerBase):
if is_terminal or max_depth_reached:
# print('MCTS went %d nodes deep' % depth)
return self._value(observation), 0 # TODO: replace with final goal reward
return self._value(
observation), 0 # TODO: replace with final goal reward
Ns = self._get_visitation_count(observation)
Nchildren = len(self.adj[self.curr_node_num])
......@@ -288,19 +300,20 @@ class MCTSLearner(ControllerBase):
self.adj[self.curr_node_num].add(new_node_num)
# Find o_star and do a transition, i.e. update curr_node
# Simulate / lookup; first change next
next_observation, episode_R, o_star = self.do_transition(observation,
visualize=visualize)
# Simulate / lookup; first change next
next_observation, episode_R, o_star = self.do_transition(
observation, visualize=visualize)
# Recursively select next node
remaining_v, all_ep_R = self._select(next_observation, depth+1, visualize=visualize)
remaining_v, all_ep_R = self._select(
next_observation, depth + 1, visualize=visualize)
# Update values
self.N[dis_observation] += 1
self.M[(dis_observation, o_star)] += 1
self.TR[(dis_observation, o_star)] += (episode_R + remaining_v)
return self._value(observation), all_ep_R+episode_R
return self._value(observation), all_ep_R + episode_R
def traverse(self, observation, visualize=False):
"""Do a complete traversal from root to leaf. Assumes the
......@@ -368,4 +381,4 @@ class MCTSLearner(ControllerBase):
next_keys, next_values = list(Q1.keys()), list(Q1.values())
o_star = next_keys[np.argmax(next_values)]
print(Q1)
return o_star
\ No newline at end of file
return o_star
......@@ -3,6 +3,7 @@ from .mcts_learner import MCTSLearner
import tqdm
import numpy as np
class OnlineMCTSController(ControllerBase):
"""Online MCTS"""
......@@ -13,12 +14,13 @@ class OnlineMCTSController(ControllerBase):
env: env instance
low_level_policies: low level policies dictionary
"""
super(OnlineMCTSController, self).__init__(env, low_level_policies, start_node_alias)
super(OnlineMCTSController, self).__init__(env, low_level_policies,
start_node_alias)
self.curr_node_alias = start_node_alias
self.controller_args_defaults = {
"predictor": None,
"max_depth": 5, # MCTS depth
"nb_traversals": 30, # MCTS traversals before decision
"max_depth": 5, # MCTS depth
"nb_traversals": 30, # MCTS traversals before decision
}
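As a usage note, these controller arguments are normally supplied through set_controller_args (the same mechanism used further down for the inner MCTSLearner). A hedged sketch follows; the start node alias 'default' is a hypothetical placeholder.

def make_online_controller(env, low_level_policies, start_node_alias='default'):
    # 'default' is a placeholder; real aliases come from low_level_policies.
    controller = OnlineMCTSController(env, low_level_policies, start_node_alias)
    controller.set_controller_args(
        predictor=None,      # optionally a learned P(s, o) value estimator
        max_depth=5,         # MCTS depth
        nb_traversals=30)    # MCTS traversals before each decision
    return controller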
def set_current_node(self, node_alias):
......@@ -48,12 +50,14 @@ class OnlineMCTSController(ControllerBase):
env_before_mcts = orig_env.copy()
self.change_low_level_references(env_before_mcts)
print('Current Node: %s' % self.curr_node_alias)
mcts = MCTSLearner(self.env, self.low_level_policies, self.curr_node_alias)
mcts = MCTSLearner(self.env, self.low_level_policies,
self.curr_node_alias)
mcts.max_depth = self.max_depth
mcts.set_controller_args(predictor = self.predictor)
mcts.set_controller_args(predictor=self.predictor)
# Do nb_traversals traversals, resetting the env to this point every time
# print('Doing MCTS with params: max_depth = %d, nb_traversals = %d' % (self.max_depth, self.nb_traversals))
for num_epoch in range(self.nb_traversals): # tqdm.tqdm(range(self.nb_traversals)):
for num_epoch in range(
self.nb_traversals): # tqdm.tqdm(range(self.nb_traversals)):
mcts.curr_node_num = 0
env_begin_epoch = env_before_mcts.copy()
self.change_low_level_references(env_begin_epoch)
......@@ -63,6 +67,7 @@ class OnlineMCTSController(ControllerBase):
# Find the best next node from the root node
mcts.curr_node_num = 0
print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
node_after_transition = mcts.get_best_node(self.env.get_features_tuple(), use_ucb=False)
node_after_transition = mcts.get_best_node(
self.env.get_features_tuple(), use_ucb=False)
print('MCTS suggested next option: %s' % node_after_transition)
self.set_current_node(node_after_transition)
\ No newline at end of file
self.set_current_node(node_after_transition)
class PolicyBase:
"""Abstract policy base from which every policy backend is defined
and inherited."""
\ No newline at end of file
"""Abstract policy base from which every policy backend is defined
and inherited."""
......@@ -11,7 +11,8 @@ class RLController(ControllerBase):
env: env instance
low_level_policies: low level policies dictionary
"""
super(RLController, self).__init__(env, low_level_policies, start_node_alias)
super(RLController, self).__init__(env, low_level_policies,
start_node_alias)
self.low_level_policy_aliases = list(self.low_level_policies.keys())
self.trained_policy = None
self.node_terminal_state_reached = False
......@@ -32,6 +33,8 @@ class RLController(ControllerBase):
if self.trained_policy is None:
raise Exception(self.__class__.__name__ + \
"trained_policy is not set. Use set_trained_policy().")
node_index_after_transition = self.trained_policy(self.env.get_features_tuple())
self.set_current_node(self.low_level_policy_aliases[node_index_after_transition])
self.node_terminal_state_reached = False
\ No newline at end of file
node_index_after_transition = self.trained_policy(
self.env.get_features_tuple())
self.set_current_node(
self.low_level_policy_aliases[node_index_after_transition])
self.node_terminal_state_reached = False
......@@ -6,19 +6,22 @@ class GymCompliantEnvBase:
""" Gym compliant step function which
will be implemented in the subclass.
"""
raise NotImplementedError(self.__class__.__name__ + ": step is not implemented.")
raise NotImplementedError(self.__class__.__name__ +
": step is not implemented.")
def reset(self):
""" Gym compliant reset function which
will be implemented in the subclass.
"""
raise NotImplementedError(self.__class__.__name__ + ": reset is not implemented.")
raise NotImplementedError(self.__class__.__name__ +
": reset is not implemented.")
def render(self):
""" Gym compliant step function which
will be implemented in the subclass.
"""
raise NotImplementedError(self.__class__.__name__ + ": render is not implemented.")
raise NotImplementedError(self.__class__.__name__ +
": render is not implemented.")
class EpisodicEnvBase(GymCompliantEnvBase):
......@@ -77,13 +80,16 @@ class EpisodicEnvBase(GymCompliantEnvBase):
return
if self.terminal_reward_type == 'min':
self._r_terminal = r_obs if self._r_terminal is None else min(self._r_terminal, r_obs)
self._r_terminal = r_obs if self._r_terminal is None else min(
self._r_terminal, r_obs)
elif self.terminal_reward_type == 'max':
self._r_terminal = r_obs if self._r_terminal is None else max(self._r_terminal, r_obs)
self._r_terminal = r_obs if self._r_terminal is None else max(
self._r_terminal, r_obs)
elif self.terminal_reward_type == 'sum':
self._r_terminal = r_obs if self._r_terminal is None else self._r_terminal + r_obs
else:
raise AssertionError("The terminal_reward_type has to be 'min', 'max', or 'sum'")
raise AssertionError(
"The terminal_reward_type has to be 'min', 'max', or 'sum'")
def step(self, u):
# the penalty is a negative reward.
......@@ -118,7 +124,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
if LTL_precondition.enabled:
LTL_precondition.check_incremental(self.__mc_AP)
if LTL_precondition.result == Parser.FALSE:
self._terminal_reward_superposition(EpisodicEnvBase._reward(LTL_precondition.penalty))
self._terminal_reward_superposition(
EpisodicEnvBase._reward(LTL_precondition.penalty))
violate = True
info['ltl_violation'] = LTL_precondition.str
# print("\nViolation of \"" + LTL_precondition.str + "\"")
......@@ -177,7 +184,6 @@ class EpisodicEnvBase(GymCompliantEnvBase):
return True
return False
# TODO: replace these confusing reward and penalty methods, or avoid using both "reward" and "penalty" in property names.
@staticmethod
def _reward(penalty):
......@@ -186,4 +192,3 @@ class EpisodicEnvBase(GymCompliantEnvBase):
@staticmethod
def _penalty(reward):
return None if reward is None else -reward
from .simple_intersection_env import SimpleIntersectionEnv
\ No newline at end of file
from .simple_intersection_env import SimpleIntersectionEnv
......@@ -34,7 +34,7 @@ con_ego_feature_dict = {
'pos_stop_region': 11
}
#: dis_ego_feature_dict contains all indexing information regarding
#: dis_ego_feature_dict contains all indexing information regarding
# each element of the ego vehicle's discrete feature vector.
#
# * not_in_stop_region: True if the ego is not in the stop region;
......@@ -69,7 +69,8 @@ ego_feature_dict = dict()
for key in con_ego_feature_dict.keys():
ego_feature_dict[key] = con_ego_feature_dict[key]
for key in dis_ego_feature_dict.keys():
ego_feature_dict[key] = dis_ego_feature_dict[key] + len(con_ego_feature_dict)
ego_feature_dict[
key] = dis_ego_feature_dict[key] + len(con_ego_feature_dict)
ego_feature_len = len(ego_feature_dict)
other_veh_feature_len = len(other_veh_feature_dict)
......@@ -81,8 +82,8 @@ def extract_ego_features(features_tuple, *args):
def extract_other_veh_features(features_tuple, veh_index, *args):
return tuple(features_tuple[ego_feature_len +
(veh_index - 1) * other_veh_feature_len + other_veh_feature_dict[key]]
for key in args)
(veh_index - 1) * other_veh_feature_len +
other_veh_feature_dict[key]] for key in args)
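A worked indexing example for extract_other_veh_features, using hypothetical sizes (the real lengths are computed from the feature dictionaries above):

# Hypothetical sizes, for illustration only.
ego_feature_len, other_veh_feature_len, key_index = 20, 5, 3
veh_index = 2
flat_index = ego_feature_len + (veh_index - 1) * other_veh_feature_len + key_index
assert flat_index == 28  # 20 + 5 + 3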
class OtherVehFeatures(object):
......@@ -109,20 +110,27 @@ class Features(object):
ego = env.ego
v_ref = env.v_ref
target_lane = env.target_lane
stop_region_length = np.abs(rd.hlanes.stop_region[1] - rd.hlanes.stop_region[0])
stop_region_length = np.abs(rd.hlanes.stop_region[1] -
rd.hlanes.stop_region[0])
assert (target_lane == True) or (target_lane == False)
# Continuous feature vector of the ego-vehicle.
self.con_ego = (
# TODO: separate this into two, one for Default and one for Finish, and try to learn.
min(rd.hlanes.stop_region[1] - ego.x, 50)/50,
ego.v, v_ref, ego.y - rd.hlanes.centres[target_lane],
ego.psi, ego.v * np.tan(ego.psi) / VEHICLE_WHEEL_BASE,
ego.theta, ego.APs['lane'],
min(rd.hlanes.stop_region[1] - ego.x, 50) / 50,
ego.v,
v_ref,
ego.y - rd.hlanes.centres[target_lane],
ego.psi,
ego.v * np.tan(ego.psi) / VEHICLE_WHEEL_BASE,
ego.theta,
ego.APs['lane'],
ego.y - rd.hlanes.centres[ego.APs['lane']],
ego.acc, ego.psi_dot,
np.clip(rd.hlanes.stop_region[1] - ego.x, 0, stop_region_length)/stop_region_length)
ego.acc,
ego.psi_dot,
np.clip(rd.hlanes.stop_region[1] - ego.x, 0, stop_region_length) /
stop_region_length)
# Discrete feature vector of the ego-vehicle.
self.dis_ego = (
......@@ -141,9 +149,9 @@ class Features(object):
# Features of the other vehicles (relative distance (x,y),
# velocity, acceleration, waited_j).
for veh in env.vehs[1:]:
self.other_vehs += (OtherVehFeatures(
ego.x - veh.x, ego.y - veh.y, veh.v, veh.acc,
veh.waited_count * DT), )
self.other_vehs += (OtherVehFeatures(ego.x - veh.x, ego.y - veh.y,
veh.v, veh.acc,
veh.waited_count * DT), )
def reset(self, env):
self.__init__(env)
......@@ -167,7 +175,7 @@ class Features(object):
other_veh.acc, other_veh.waited_time)
# Add buffer features to make a fixed length feature vector
for i in range(MAX_NUM_VEHICLES-len(self.other_vehs)):
for i in range(MAX_NUM_VEHICLES - len(self.other_vehs)):
feature += (0.0, 0.0, 0.0, 0.0, -1)
return feature
......@@ -54,6 +54,7 @@ class Route(object):
def pos_range(self):
return self.min_pos, self.max_pos
#: the intersection is positioned at (0, 0) in the coordinate system.
# The variable "intersection_pos" indicates this position of
# the intersection, but it is not used at all in the current implementation.
......@@ -78,18 +79,19 @@ speed_limit = 11.176
# (i.e., y-axis flipped). The coordinate system we consider
# is the same: y-axis flipped.
hwidth = 4.0