wise-lab / wise-move

Commit 72e44d55, authored Nov 19, 2018 by Ashish Gaurav
Parent: e1fdb162

format using yapf
Showing 20 changed files with 537 additions and 299 deletions.
backends/baselines_learner.py                        +7    -3
backends/controller_base.py                          +8    -3
backends/kerasrl_learner.py                          +72   -43
backends/learner_base.py                             +5    -5
backends/manual_policy.py                            +6    -3
backends/mcts_learner.py                             +35   -22
backends/online_mcts_controller.py                   +13   -8
backends/policy_base.py                              +2    -2
backends/rl_controller.py                            +7    -4
env/env_base.py                                      +14   -9
env/simple_intersection/__init__.py                  +1    -1
env/simple_intersection/features.py                  +23   -15
env/simple_intersection/road_geokinemetry.py         +10   -8
env/simple_intersection/road_networks.py             +20   -10
env/simple_intersection/shapes.py                    +4    -2
env/simple_intersection/simple_intersection_env.py   +155  -85
env/simple_intersection/utilities.py                 +5    -5
env/simple_intersection/vehicle_networks.py          +7    -3
env/simple_intersection/vehicles.py                  +31   -18
high_level_policy_main.py                            +112  -50
backends/baselines_learner.py

@@ -35,7 +35,9 @@ class PPO2Agent(LearnerBase):
        self.log_path = log_path
        self.env = DummyVecEnv(
            [lambda: env])  #PPO2 requried a vectorized environment for parallel training
        self.agent_model = self.create_agent(policy, tensorboard)

    def get_default_policy(self):

@@ -52,7 +54,8 @@ class PPO2Agent(LearnerBase):
            stable_baselines PPO2 object
        """
        if tensorboard:
            return PPO2(policy, self.env, verbose=1,
                        tensorboard_log=self.log_path)
        else:
            return PPO2(policy, self.env, verbose=1)

@@ -100,7 +103,8 @@ class PPO2Agent(LearnerBase):
            episode_rewards[-1] += rewards[0]
            if dones[0] or current_step > nb_max_episode_steps:
                obs = self.env.reset()
                print("Episode ", current_episode, "reward: ",
                      episode_rewards[-1])
                episode_rewards.append(0.0)
                current_episode += 1
                current_step = 0
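The hunk above wraps the single training environment in DummyVecEnv because stable-baselines' PPO2 only accepts vectorized environments. A minimal sketch of that pattern, assuming stable-baselines 2.x and gym are installed; CartPole-v1 stands in here for the repo's own environment:

import gym
from stable_baselines import PPO2
from stable_baselines.common.vec_env import DummyVecEnv

# PPO2 needs a VecEnv, so the single env is wrapped in DummyVecEnv.
env = gym.make('CartPole-v1')
vec_env = DummyVecEnv([lambda: env])
model = PPO2('MlpPolicy', vec_env, verbose=1, tensorboard_log='./logs')
model.learn(total_timesteps=10000)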
backends/controller_base.py

@@ -50,11 +50,13 @@ class ControllerBase(PolicyBase):
    Returns state at end of node execution, total reward, epsiode_termination_flag, info
    '''

    def step_current_node(self, visualize_low_level_steps=False):
        total_reward = 0
        self.node_terminal_state_reached = False
        while not self.node_terminal_state_reached:
            observation, reward, terminal, info = self.low_level_step_current_node()
            if visualize_low_level_steps:
                self.env.render()
            total_reward += reward

@@ -70,9 +72,12 @@ class ControllerBase(PolicyBase):
    Returns state after one step, step reward, episode_termination_flag, info
    '''

    def low_level_step_current_node(self):
        u_ego = self.current_node.low_level_policy(
            self.current_node.get_reduced_features_tuple())
        feature, R, terminal, info = self.current_node.step(u_ego)
        self.node_terminal_state_reached = terminal
        return self.env.get_features_tuple(), R, self.env.termination_condition, info
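step_current_node above runs the active option's low-level policy until the option reports a terminal state, accumulating the reward collected along the way. A self-contained sketch of that semi-MDP execution loop, with a toy option standing in for the repo's policy nodes:

class ToyOption:
    """Stands in for a low-level policy node; terminates after a few steps."""

    def __init__(self, length=3):
        self.remaining = length

    def low_level_policy(self, features):
        return 0.0  # dummy control input

    def step(self, u_ego):
        self.remaining -= 1
        return None, 1.0, self.remaining <= 0, {}


def step_option(option):
    """Run one option to termination and return its accumulated reward."""
    total_reward, terminal = 0.0, False
    while not terminal:
        u_ego = option.low_level_policy(None)
        _, reward, terminal, _ = option.step(u_ego)
        total_reward += reward
    return total_reward


print(step_option(ToyOption()))  # 3.0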
backends/kerasrl_learner.py

@@ -47,7 +47,8 @@ class DDPGLearner(LearnerBase):
            "oup_mu": 0,  # OrnsteinUhlenbeckProcess mu
            "oup_sigma": 1,  # OrnsteinUhlenbeckProcess sigma
            "oup_sigma_min": 0.5,  # OrnsteinUhlenbeckProcess sigma min
            "oup_annealing_steps": 500000,  # OrnsteinUhlenbeckProcess n-step annealing
            "nb_steps_warmup_critic": 100,  # steps for critic to warmup
            "nb_steps_warmup_actor": 100,  # steps for actor to warmup
            "target_model_update": 1e-3  # target model update frequency
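The oup_* defaults above parametrize the Ornstein-Uhlenbeck exploration noise that keras-rl's DDPG agent adds to the actor's output. A sketch of how such a process is typically constructed for the random_process argument of DDPGAgent; theta and nb_actions are illustrative values, not taken from this file:

from rl.random import OrnsteinUhlenbeckProcess

nb_actions = 2  # assumed size of the continuous action vector
random_process = OrnsteinUhlenbeckProcess(
    theta=0.15,                # assumed mean-reversion rate
    mu=0,                      # oup_mu
    sigma=1,                   # oup_sigma
    sigma_min=0.5,             # oup_sigma_min
    n_steps_annealing=500000,  # oup_annealing_steps
    size=nb_actions)
noise = random_process.sample()  # one noise vector per action step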
@@ -160,24 +161,33 @@ class DDPGLearner(LearnerBase):
            target_model_update=1e-3)

        # TODO: give params like lr_actor and lr_critic to set different lr of Actor and Critic.
        agent.compile(
            [
                Adam(lr=self.lr * 1e-2, clipnorm=1.),
                Adam(lr=self.lr, clipnorm=1.)
            ],
            metrics=['mae'])

        return agent

    def train(self,
              env,
              nb_steps=1000000,
              visualize=False,
              verbose=1,
              log_interval=10000,
              nb_max_episode_steps=200,
              model_checkpoints=False,
              checkpoint_interval=100000,
              tensorboard=False):

        callbacks = []
        if model_checkpoints:
            callbacks += [
                ModelIntervalCheckpoint(
                    './checkpoints/checkpoint_weights.h5f',
                    interval=checkpoint_interval)
            ]
        if tensorboard:
            callbacks += [TensorBoard(log_dir='./logs')]
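The train method above only assembles the callback list; keras-rl agents consume it through fit. A sketch of that wiring as a helper, assuming keras-rl and Keras are installed; the name fit_with_logging and its defaults are illustrative, and the compiled agent and Gym env are supplied by the caller:

from rl.callbacks import ModelIntervalCheckpoint
from keras.callbacks import TensorBoard


def fit_with_logging(agent, env, nb_steps=1000000, checkpoint_interval=100000):
    """Fit a compiled keras-rl agent with checkpointing and TensorBoard logs."""
    callbacks = [
        ModelIntervalCheckpoint(
            './checkpoints/checkpoint_weights.h5f',
            interval=checkpoint_interval),
        TensorBoard(log_dir='./logs')
    ]
    agent.fit(
        env,
        nb_steps=nb_steps,
        visualize=False,
        verbose=1,
        log_interval=10000,
        nb_max_episode_steps=200,
        callbacks=callbacks)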
@@ -291,28 +301,36 @@ class DQNLearner(LearnerBase):
        Returns:
            KerasRL DQN object
        """
        agent = DQNAgentOverOptions(
            model=model,
            low_level_policies=self.low_level_policies,
            nb_actions=self.nb_actions,
            memory=memory,
            nb_steps_warmup=self.nb_steps_warmup,
            target_model_update=self.target_model_update,
            policy=policy,
            enable_dueling_network=True)

        agent.compile(Adam(lr=self.lr), metrics=['mae'])

        return agent

    def train(self,
              env,
              nb_steps=1000000,
              visualize=False,
              nb_max_episode_steps=200,
              tensorboard=False,
              model_checkpoints=False,
              checkpoint_interval=10000):

        callbacks = []
        if model_checkpoints:
            callbacks += [
                ModelIntervalCheckpoint(
                    './checkpoints/checkpoint_weights.h5f',
                    interval=checkpoint_interval)
            ]
        if tensorboard:
            callbacks += [TensorBoard(log_dir='./logs')]
@@ -333,7 +351,7 @@ class DQNLearner(LearnerBase):
                 nb_episodes=5,
                 visualize=True,
                 nb_max_episode_steps=400,
                 success_reward_threshold=100):

        print("Testing for {} episodes".format(nb_episodes))
        success_count = 0
@@ -359,13 +377,14 @@ class DQNLearner(LearnerBase):
            env.reset()
            if episode_reward >= success_reward_threshold:
                success_count += 1
            print("Episode {}: steps:{}, reward:{}".format(
                n + 1, step, episode_reward))

        print("\nPolicy succeeded {} times!".format(success_count))
        print("Failures due to:")
        print(termination_reason_counter)

        return [success_count, termination_reason_counter]

    def load_model(self, file_name="test_weights.h5f"):
        self.agent_model.load_weights(file_name)
@@ -377,27 +396,39 @@ class DQNLearner(LearnerBase):
        return self.agent_model.get_modified_q_values(observation)[action]

    def get_q_value_using_option_alias(self, observation, option_alias):
        action_num = self.agent_model.low_level_policy_aliases.index(
            option_alias)
        return self.agent_model.get_modified_q_values(observation)[action_num]

    def get_softq_value_using_option_alias(self, observation, option_alias):
        action_num = self.agent_model.low_level_policy_aliases.index(
            option_alias)
        q_values = self.agent_model.get_modified_q_values(observation)
        max_q_value = np.abs(np.max(q_values))
        q_values = [np.exp(q_value / max_q_value) for q_value in q_values]
        relevant = q_values[action_num] / np.sum(q_values)
        return relevant


class DQNAgentOverOptions(DQNAgent):
    def __init__(self,
                 model,
                 low_level_policies,
                 policy=None,
                 test_policy=None,
                 enable_double_dqn=True,
                 enable_dueling_network=False,
                 dueling_type='avg',
                 *args,
                 **kwargs):
        super(DQNAgentOverOptions, self).__init__(
            model, policy, test_policy, enable_double_dqn,
            enable_dueling_network, dueling_type, *args, **kwargs)

        self.low_level_policies = low_level_policies
        if low_level_policies is not None:
            self.low_level_policy_aliases = list(
                self.low_level_policies.keys())

    def __get_invalid_node_indices(self):
        """Returns a list of option indices that are invalid according to initiation conditions.
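get_softq_value_using_option_alias above turns the raw Q-values into a softmax-like weight for one option: scale by the magnitude of the largest Q-value, exponentiate, and normalize. A numpy-only sketch of the same computation on made-up values:

import numpy as np


def soft_q_weight(q_values, action_num):
    """Exponentiate Q-values scaled by the largest magnitude, then normalize."""
    q_values = np.asarray(q_values, dtype=float)
    max_q_value = np.abs(np.max(q_values))
    exp_q = np.exp(q_values / max_q_value)
    return exp_q[action_num] / np.sum(exp_q)


print(soft_q_weight([2.0, 1.0, -0.5], action_num=0))  # largest share goes to index 0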
@@ -435,5 +466,3 @@ class DQNAgentOverOptions(DQNAgent):
            q_values[node_index] = -np.inf

        return q_values
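The tail of the hunk above is the option-masking step: Q-values of options whose initiation conditions fail are forced to -inf, so an argmax over the remaining values can never select them. A small numpy sketch of that idea with hypothetical option indices:

import numpy as np


def mask_invalid_options(q_values, invalid_indices):
    """Set Q-values of non-initiable options to -inf so argmax ignores them."""
    q_values = np.array(q_values, dtype=float)
    for node_index in invalid_indices:
        q_values[node_index] = -np.inf
    return q_values


masked = mask_invalid_options([0.3, 1.2, 0.7], invalid_indices=[1])
print(int(np.argmax(masked)))  # 2, because option 1 was masked out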
backends/learner_base.py

from .policy_base import PolicyBase
import numpy as np

@@ -23,10 +23,10 @@ class LearnerBase(PolicyBase):
            setattr(self, prop, kwargs.get(prop, default))

    def train(self,
              env,
              nb_steps=50000,
              visualize=False,
              nb_max_episode_steps=200):
        """Train the learning agent on the environment.

        Args:
backends/manual_policy.py

from .controller_base import ControllerBase


class ManualPolicy(ControllerBase):
    """Manual policy execution using nodes and edges."""

    def __init__(self, env, low_level_policies, transition_adj,
                 start_node_alias):
        """Constructor for manual policy execution.

        Args:

@@ -13,7 +15,8 @@ class ManualPolicy(ControllerBase):
            start_node: starting node
        """
        super(ManualPolicy, self).__init__(env, low_level_policies,
                                           start_node_alias)
        self.adj = transition_adj

    def _transition(self):

@@ -50,4 +53,4 @@ class ManualPolicy(ControllerBase):
        new_node = self._transition()
        if new_node is not None:
            self.current_node = new_node
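ManualPolicy stores transition_adj, an adjacency structure that fixes which node may follow which; its exact layout is not visible in these hunks. A hypothetical example, with invented option aliases, might look like:

# Hypothetical adjacency: current option alias -> aliases it may hand over to.
transition_adj = {
    'wait': ['follow'],
    'follow': ['stop'],
    'stop': None,  # terminal node, no further transition
}

print(transition_adj['wait'])  # ['follow']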
backends/mcts_learner.py

@@ -2,6 +2,7 @@ from .controller_base import ControllerBase
import numpy as np
import pickle


class MCTSLearner(ControllerBase):
    """Monte Carlo Tree Search implementation using the UCB1 and
    progressive widening approach as explained in Paxton et al (2017).

@@ -9,8 +10,8 @@ class MCTSLearner(ControllerBase):
    _ucb_vals = set()

    def __init__(self, env, low_level_policies, start_node_alias,
                 max_depth=10):
        """Constructor for MCTSLearner.

        Args:

@@ -22,10 +23,12 @@ class MCTSLearner(ControllerBase):
            max_depth: max depth of the MCTS tree; default 10 levels
        """
        super(MCTSLearner, self).__init__(env, low_level_policies,
                                          start_node_alias)
        self.controller_args_defaults = {
            "predictor": None
            #P(s, o) learner class; forward pass should return the entire value from state s and option o
        }
        self.max_depth = max_depth

        #: store current node alias

@@ -48,11 +51,17 @@ class MCTSLearner(ControllerBase):
        # populate root node
        root_node_num, root_node_info = self._create_node(
            self.curr_node_alias)
        self.nodes[root_node_num] = root_node_info
        self.adj[root_node_num] = set()  # no children

    def save_model(self, file_name="mcts.pickle"):
        to_backup = {
            'N': self.N,
            'M': self.M,
            'TR': self.TR,
            'nodes': self.nodes,
            'adj': self.adj,
            'new_node_num': self.new_node_num
        }
        with open(file_name, 'wb') as handle:
            pickle.dump(to_backup, handle, protocol=pickle.HIGHEST_PROTOCOL)

@@ -96,9 +105,9 @@ class MCTSLearner(ControllerBase):
        """
        dis_observation = ''
        for item in observation[12:20]:
            if type(item) == bool:
                dis_observation += '1' if item is True else '0'
            if type(item) == int and item in [0, 1]:
                dis_observation += str(item)

        env = self.current_node.env

@@ -163,7 +172,8 @@ class MCTSLearner(ControllerBase):
        dis_observation = self._to_discrete(observation)
        if (dis_observation, option) not in self.TR:
            self.TR[(dis_observation, option)] = 0
        return self.TR[(dis_observation, option)] / (
            1 + self._get_visitation_count(observation, option))

    def _get_possible_options(self):
        """Returns a set of options that can be taken from the current node.

@@ -172,7 +182,7 @@ class MCTSLearner(ControllerBase):
        """
        all_options = set(self.low_level_policies.keys())

        # Filter nodes whose initiation condition are true
        filtered_options = set()
        for option_alias in all_options:

@@ -211,14 +221,15 @@ class MCTSLearner(ControllerBase):
        Returns Q values for next nodes
        """
        Q = {}
        Q1, Q2 = {}, {}  # debug
        dis_observation = self._to_discrete(observation)
        next_option_nums = self.adj[self.curr_node_num]
        for next_option_num in next_option_nums:
            next_option = self.nodes[next_option_num]["policy"]
            Q1[next_option] = (
                self._get_q_star(observation, next_option) + 200) / 400
            Q[(dis_observation, next_option)] = \
                Q1[next_option]
            Q2[next_option] = C * \

@@ -243,7 +254,7 @@ class MCTSLearner(ControllerBase):
        relevant_rewards = [value for key, value in self.TR.items() \
                            if key[0] == dis_observation]
        sum_rewards = np.sum(relevant_rewards)
        return sum_rewards / (
            1 + self._get_visitation_count(observation))

    def _select(self, observation, depth=0, visualize=False):
        """MCTS selection function. For representation, we only use

@@ -266,7 +277,8 @@ class MCTSLearner(ControllerBase):
        if is_terminal or max_depth_reached:
            # print('MCTS went %d nodes deep' % depth)
            return self._value(
                observation), 0  # TODO: replace with final goal reward

        Ns = self._get_visitation_count(observation)
        Nchildren = len(self.adj[self.curr_node_num])

@@ -288,19 +300,20 @@ class MCTSLearner(ControllerBase):
            self.adj[self.curr_node_num].add(new_node_num)

        # Find o_star and do a transition, i.e. update curr_node
        # Simulate / lookup; first change next
        next_observation, episode_R, o_star = self.do_transition(
            observation, visualize=visualize)

        # Recursively select next node
        remaining_v, all_ep_R = self._select(
            next_observation, depth + 1, visualize=visualize)

        # Update values
        self.N[dis_observation] += 1
        self.M[(dis_observation, o_star)] += 1
        self.TR[(dis_observation, o_star)] += (episode_R + remaining_v)
        return self._value(observation), all_ep_R + episode_R

    def traverse(self, observation, visualize=False):
        """Do a complete traversal from root to leaf. Assumes the

@@ -368,4 +381,4 @@ class MCTSLearner(ControllerBase):
        next_keys, next_values = list(Q1.keys()), list(Q1.values())
        o_star = next_keys[np.argmax(next_values)]
        print(Q1)
        return o_star
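Per the class docstring, the search combines UCB1 with progressive widening (Paxton et al., 2017): a node gets a new child only while its child count stays small relative to its visit count, and otherwise the child with the best exploration-adjusted score is followed. A standalone sketch of those two rules; the constants C, k and alpha are illustrative, not values read from this file:

import numpy as np


def ucb1(q_value, n_state, n_state_option, C=1.0):
    """UCB1 score: exploitation term plus an exploration bonus."""
    return q_value + C * np.sqrt(np.log(1 + n_state) / (1 + n_state_option))


def should_widen(n_children, n_state, k=1.0, alpha=0.5):
    """Progressive widening: expand while children <= k * N(s)^alpha."""
    return n_children <= k * (n_state ** alpha)


scores = {'follow': ucb1(0.4, n_state=20, n_state_option=5),
          'stop': ucb1(0.6, n_state=20, n_state_option=12)}
print(max(scores, key=scores.get), should_widen(n_children=2, n_state=20))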
backends/online_mcts_controller.py

@@ -3,6 +3,7 @@ from .mcts_learner import MCTSLearner
import tqdm
import numpy as np


class OnlineMCTSController(ControllerBase):
    """Online MCTS"""

@@ -13,12 +14,13 @@ class OnlineMCTSController(ControllerBase):
            env: env instance
            low_level_policies: low level policies dictionary
        """
        super(OnlineMCTSController, self).__init__(env, low_level_policies,
                                                   start_node_alias)
        self.curr_node_alias = start_node_alias
        self.controller_args_defaults = {
            "predictor": None,
            "max_depth": 5,  # MCTS depth
            "nb_traversals": 30,  # MCTS traversals before decision
        }

    def set_current_node(self, node_alias):

@@ -48,12 +50,14 @@ class OnlineMCTSController(ControllerBase):
        env_before_mcts = orig_env.copy()
        self.change_low_level_references(env_before_mcts)
        print('Current Node: %s' % self.curr_node_alias)
        mcts = MCTSLearner(self.env, self.low_level_policies,
                           self.curr_node_alias)
        mcts.max_depth = self.max_depth
        mcts.set_controller_args(predictor=self.predictor)
        # Do nb_traversals number of traversals, reset env to this point every time
        # print('Doing MCTS with params: max_depth = %d, nb_traversals = %d' % (self.max_depth, self.nb_traversals))
        for num_epoch in range(
                self.nb_traversals):  # tqdm.tqdm(range(self.nb_traversals)):
            mcts.curr_node_num = 0
            env_begin_epoch = env_before_mcts.copy()
            self.change_low_level_references(env_begin_epoch)

@@ -63,6 +67,7 @@ class OnlineMCTSController(ControllerBase):
        # Find the nodes from the root node
        mcts.curr_node_num = 0
        print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
        node_after_transition = mcts.get_best_node(
            self.env.get_features_tuple(), use_ucb=False)
        print('MCTS suggested next option: %s' % node_after_transition)
        self.set_current_node(node_after_transition)
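The hunks above follow the usual online-planning pattern: snapshot the environment, run a fixed budget of MCTS traversals on the snapshot, then query the tree without the UCB bonus for the best next option. A compact sketch of that control flow, with a stub planner standing in for MCTSLearner:

import copy


class StubPlanner:
    """Stands in for MCTSLearner; pretends traversals improve a value table."""

    def __init__(self):
        self.values = {'follow': 0.0, 'stop': 0.0}

    def traverse(self, state):
        self.values['follow'] += 0.1  # pretend 'follow' keeps looking better

    def get_best_node(self, state, use_ucb=False):
        return max(self.values, key=self.values.get)


def plan_next_option(env_state, nb_traversals=30):
    planner = StubPlanner()
    snapshot = copy.deepcopy(env_state)  # plan on a copy, not the live env
    for _ in range(nb_traversals):
        planner.traverse(copy.deepcopy(snapshot))
    return planner.get_best_node(snapshot, use_ucb=False)


print(plan_next_option({'x': 0.0}))  # 'follow'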
backends/policy_base.py

class PolicyBase:
    """Abstract policy base from which every policy backend is defined
    and inherited."""
backends/rl_controller.py

@@ -11,7 +11,8 @@ class RLController(ControllerBase):
            env: env instance
            low_level_policies: low level policies dictionary
        """
        super(RLController, self).__init__(env, low_level_policies,
                                           start_node_alias)
        self.low_level_policy_aliases = list(self.low_level_policies.keys())
        self.trained_policy = None
        self.node_terminal_state_reached = False

@@ -32,6 +33,8 @@ class RLController(ControllerBase):
        if self.trained_policy is None:
            raise Exception(self.__class__.__name__ + \
                "trained_policy is not set. Use set_trained_policy().")
        node_index_after_transition = self.trained_policy(
            self.env.get_features_tuple())
        self.set_current_node(
            self.low_level_policy_aliases[node_index_after_transition])
        self.node_terminal_state_reached = False
env/env_base.py

@@ -6,19 +6,22 @@ class GymCompliantEnvBase:
        """ Gym compliant step function which
            will be implemented in the subclass.
        """
        raise NotImplemented(self.__class__.__name__ +
                             "step is not implemented.")

    def reset(self):
        """ Gym compliant reset function which
            will be implemented in the subclass.
        """
        raise NotImplemented(self.__class__.__name__ +
                             "reset is not implemented.")

    def render(self):
        """ Gym compliant step function which
            will be implemented in the subclass.
        """
        raise NotImplemented(self.__class__.__name__ +
                             "render is not implemented.")


class EpisodicEnvBase(GymCompliantEnvBase):
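GymCompliantEnvBase above only declares the Gym-style step/reset/render interface and raises when a subclass leaves a method unimplemented. A toy subclass sketch of that contract (this environment is not part of the repo, and it raises Python's standard NotImplementedError for the method it leaves out):

class CountdownEnv:
    """Toy Gym-style environment: reward 1 per step until the counter hits zero."""

    def __init__(self, start=3):
        self.start = start
        self.state = start

    def step(self, u):
        self.state -= 1
        return self.state, 1.0, self.state <= 0, {}

    def reset(self):
        self.state = self.start
        return self.state

    def render(self):
        raise NotImplementedError(
            self.__class__.__name__ + ": render is not implemented.")


env = CountdownEnv()
obs, done = env.reset(), False
while not done:
    obs, reward, done, info = env.step(None)
print(obs)  # 0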
@@ -77,13 +80,16 @@ class EpisodicEnvBase(GymCompliantEnvBase):
            return

        if self.terminal_reward_type == 'min':
            self._r_terminal = r_obs if self._r_terminal is None else min(
                self._r_terminal, r_obs)
        elif self.terminal_reward_type == 'max':
            self._r_terminal = r_obs if self._r_terminal is None else max(
                self._r_terminal, r_obs)
        elif self.terminal_reward_type == 'sum':
            self._r_terminal = r_obs if self._r_terminal is None else self._r_terminal + r_obs
        else:
            raise AssertionError(
                "The terminal_reward_type has to be 'min', 'max', or 'sum'")

    def step(self, u):

        # the penalty is a negative reward.
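The hunk above folds a newly observed terminal reward into the stored one according to terminal_reward_type: keep the minimum, keep the maximum, or accumulate the sum. A standalone sketch of that rule:

def superpose_terminal_reward(current, r_obs, mode):
    """Combine a stored terminal reward with a new one by 'min', 'max', or 'sum'."""
    if r_obs is None:
        return current
    if current is None:
        return r_obs
    if mode == 'min':
        return min(current, r_obs)
    elif mode == 'max':
        return max(current, r_obs)
    elif mode == 'sum':
        return current + r_obs
    raise AssertionError("The terminal_reward_type has to be 'min', 'max', or 'sum'")


print(superpose_terminal_reward(-100, -50, 'min'))  # -100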
@@ -118,7 +124,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
            if LTL_precondition.enabled:
                LTL_precondition.check_incremental(self.__mc_AP)
                if LTL_precondition.result == Parser.FALSE:
                    self._terminal_reward_superposition(
                        EpisodicEnvBase._reward(LTL_precondition.penalty))
                    violate = True
                    info['ltl_violation'] = LTL_precondition.str
                    # print("\nViolation of \"" + LTL_precondition.str + "\"")

@@ -177,7 +184,6 @@ class EpisodicEnvBase(GymCompliantEnvBase):
                return True

        return False

    # TODO: replace these confusing methods reward and penalty, or not to use both reward and penalty for the property naming.
    @staticmethod
    def _reward(penalty):

@@ -186,4 +192,3 @@ class EpisodicEnvBase(GymCompliantEnvBase):
    @staticmethod
    def _penalty(reward):
        return None if reward is None else -reward
env/simple_intersection/__init__.py

from .simple_intersection_env import SimpleIntersectionEnv
env/simple_intersection/features.py

@@ -34,7 +34,7 @@ con_ego_feature_dict = {
    'pos_stop_region': 11
}

#: dis_ego_feature_dict contains all indexing information regarding
# each element of the ego vehicle's discrete feature vector.