wise-lab / wise-move
Commits
Commit 72e44d55
authored Nov 19, 2018 by Ashish Gaurav
parent e1fdb162

format using yapf

Showing 35 changed files with 1757 additions and 1325 deletions (+1757, -1325)
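The commit message indicates a tree-wide reformat with yapf. As a rough sketch only (the yapf version, style settings, and file selection used for this commit are not recorded on this page), a pass like this can be reproduced with yapf's Python API:

    # Sketch: repository-wide reformatting with yapf's Python API.
    # Assumptions (not taken from this page): style 'pep8' and the glob
    # pattern; the project may instead rely on a .style.yapf file.
    import glob

    from yapf.yapflib.yapf_api import FormatFile

    for path in glob.glob('**/*.py', recursive=True):
        # in_place=True rewrites the file; the third return value reports
        # whether anything actually changed.
        _, _, changed = FormatFile(path, in_place=True, style_config='pep8')
        if changed:
            print('reformatted', path)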
Changed files:

backends/baselines_learner.py  +7 -3
backends/controller_base.py  +8 -3
backends/kerasrl_learner.py  +72 -43
backends/learner_base.py  +5 -5
backends/manual_policy.py  +6 -3
backends/mcts_learner.py  +35 -22
backends/online_mcts_controller.py  +13 -8
backends/policy_base.py  +2 -2
backends/rl_controller.py  +7 -4
env/env_base.py  +14 -9
env/simple_intersection/__init__.py  +1 -1
env/simple_intersection/features.py  +23 -15
env/simple_intersection/road_geokinemetry.py  +10 -8
env/simple_intersection/road_networks.py  +20 -10
env/simple_intersection/shapes.py  +4 -2
env/simple_intersection/simple_intersection_env.py  +155 -85
env/simple_intersection/utilities.py  +5 -5
env/simple_intersection/vehicle_networks.py  +7 -3
env/simple_intersection/vehicles.py  +31 -18
high_level_policy_main.py  +112 -50
low_level_policy_main.py  +84 -38
mcts.py  +118 -55
model_checker/LTL_property_base.py  +5 -5
model_checker/atomic_propositions_base.py  +0 -1
model_checker/parser.py  +466 -465
model_checker/scanner.py  +303 -308
model_checker/simple_intersection/AP_dict.py  +0 -1
model_checker/simple_intersection/LTL_test.py  +28 -14
model_checker/simple_intersection/__init__.py  +1 -1
model_checker/simple_intersection/classes.py  +1 -1
options/options_loader.py  +28 -18
options/simple_intersection/maneuver_base.py  +34 -16
options/simple_intersection/maneuvers.py  +78 -59
options/simple_intersection/mcts_maneuvers.py  +64 -39
ppo2_training.py  +10 -5
backends/baselines_learner.py  (View file @ 72e44d55)

@@ -35,7 +35,9 @@ class PPO2Agent(LearnerBase):
        self.log_path = log_path
        self.env = DummyVecEnv([lambda: env])  #PPO2 requried a vectorized environment for parallel training
        self.agent_model = self.create_agent(policy, tensorboard)

    def get_default_policy(self):
...
@@ -52,7 +54,8 @@ class PPO2Agent(LearnerBase):
            stable_baselines PPO2 object
        """
        if tensorboard:
            return PPO2(policy, self.env, verbose=1, tensorboard_log=self.log_path)
        else:
            return PPO2(policy, self.env, verbose=1)
...
@@ -100,7 +103,8 @@ class PPO2Agent(LearnerBase):
                episode_rewards[-1] += rewards[0]
                if dones[0] or current_step > nb_max_episode_steps:
                    obs = self.env.reset()
                    print("Episode ", current_episode, "reward: ", episode_rewards[-1])
                    episode_rewards.append(0.0)
                    current_episode += 1
                    current_step = 0
...
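For context on the DummyVecEnv call above: PPO2 in stable-baselines 2.x operates on vectorized environments, so even a single environment is wrapped in a one-element DummyVecEnv. A minimal sketch, assuming the stable-baselines 2.x API this file imports; CartPole is only a stand-in environment, not part of this project:

    import gym
    from stable_baselines import PPO2
    from stable_baselines.common.vec_env import DummyVecEnv

    # A one-element "vector" of environments satisfies PPO2's interface.
    env = DummyVecEnv([lambda: gym.make('CartPole-v1')])
    model = PPO2('MlpPolicy', env, verbose=1)
    model.learn(total_timesteps=10000)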
backends/controller_base.py  (View file @ 72e44d55)

@@ -50,11 +50,13 @@ class ControllerBase(PolicyBase):
    Returns state at end of node execution, total reward, epsiode_termination_flag, info
    '''

    def step_current_node(self, visualize_low_level_steps=False):
        total_reward = 0
        self.node_terminal_state_reached = False
        while not self.node_terminal_state_reached:
            observation, reward, terminal, info = self.low_level_step_current_node()
            if visualize_low_level_steps:
                self.env.render()
            total_reward += reward
...
@@ -70,9 +72,12 @@ class ControllerBase(PolicyBase):
    Returns state after one step, step reward, episode_termination_flag, info
    '''

    def low_level_step_current_node(self):
        u_ego = self.current_node.low_level_policy(self.current_node.get_reduced_features_tuple())
        feature, R, terminal, info = self.current_node.step(u_ego)
        self.node_terminal_state_reached = terminal
        return self.env.get_features_tuple(), R, self.env.termination_condition, info
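The two methods above split option execution in two: low_level_step_current_node takes one low-level action, and step_current_node loops it until the option signals termination. A standalone sketch of the combined loop, with illustrative names; this is not code from the repository:

    # Sketch of the "run option until it terminates" pattern, assuming an env
    # and option object with the same interface as in the code above.
    def run_option_to_termination(env, option, render=False):
        total_reward, terminal, info = 0.0, False, {}
        while not terminal:
            u = option.low_level_policy(option.get_reduced_features_tuple())
            _, r, terminal, info = option.step(u)
            if render:
                env.render()
            total_reward += r
        return env.get_features_tuple(), total_reward, env.termination_condition, info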
backends/kerasrl_learner.py  (View file @ 72e44d55)

@@ -47,7 +47,8 @@ class DDPGLearner(LearnerBase):
            "oup_mu": 0,  # OrnsteinUhlenbeckProcess mu
            "oup_sigma": 1,  # OrnsteinUhlenbeckProcess sigma
            "oup_sigma_min": 0.5,  # OrnsteinUhlenbeckProcess sigma min
            "oup_annealing_steps": 500000,  # OrnsteinUhlenbeckProcess n-step annealing
            "nb_steps_warmup_critic": 100,  # steps for critic to warmup
            "nb_steps_warmup_actor": 100,  # steps for actor to warmup
            "target_model_update": 1e-3  # target model update frequency
...
@@ -160,24 +161,33 @@ class DDPGLearner(LearnerBase):
            target_model_update=1e-3)

        # TODO: give params like lr_actor and lr_critic to set different lr of Actor and Critic.
        agent.compile([Adam(lr=self.lr * 1e-2, clipnorm=1.), Adam(lr=self.lr, clipnorm=1.)], metrics=['mae'])
        return agent

    def train(self, env, nb_steps=1000000, visualize=False, verbose=1, log_interval=10000, nb_max_episode_steps=200, model_checkpoints=False, checkpoint_interval=100000, tensorboard=False):
        callbacks = []
        if model_checkpoints:
            callbacks += [ModelIntervalCheckpoint('./checkpoints/checkpoint_weights.h5f', interval=checkpoint_interval)]
        if tensorboard:
            callbacks += [TensorBoard(log_dir='./logs')]
...
@@ -291,28 +301,36 @@ class DQNLearner(LearnerBase):
        Returns:
            KerasRL DQN object
        """
        agent = DQNAgentOverOptions(model=model, low_level_policies=self.low_level_policies, nb_actions=self.nb_actions, memory=memory, nb_steps_warmup=self.nb_steps_warmup, target_model_update=self.target_model_update, policy=policy, enable_dueling_network=True)
        agent.compile(Adam(lr=self.lr), metrics=['mae'])
        return agent

    def train(self, env, nb_steps=1000000, visualize=False, nb_max_episode_steps=200, tensorboard=False, model_checkpoints=False, checkpoint_interval=10000):
        callbacks = []
        if model_checkpoints:
            callbacks += [ModelIntervalCheckpoint('./checkpoints/checkpoint_weights.h5f', interval=checkpoint_interval)]
        if tensorboard:
            callbacks += [TensorBoard(log_dir='./logs')]
...
@@ -333,7 +351,7 @@ class DQNLearner(LearnerBase):
             nb_episodes=5, visualize=True, nb_max_episode_steps=400, success_reward_threshold=100):
        print("Testing for {} episodes".format(nb_episodes))
        success_count = 0
...
@@ -359,13 +377,14 @@ class DQNLearner(LearnerBase):
                env.reset()
                if episode_reward >= success_reward_threshold:
                    success_count += 1
                print("Episode {}: steps:{}, reward:{}".format(n + 1, step, episode_reward))

        print("\nPolicy succeeded {} times!".format(success_count))
        print("Failures due to:")
        print(termination_reason_counter)
        return [success_count, termination_reason_counter]

    def load_model(self, file_name="test_weights.h5f"):
        self.agent_model.load_weights(file_name)
...
@@ -377,27 +396,39 @@ class DQNLearner(LearnerBase):
        return self.agent_model.get_modified_q_values(observation)[action]

    def get_q_value_using_option_alias(self, observation, option_alias):
        action_num = self.agent_model.low_level_policy_aliases.index(option_alias)
        return self.agent_model.get_modified_q_values(observation)[action_num]

    def get_softq_value_using_option_alias(self, observation, option_alias):
        action_num = self.agent_model.low_level_policy_aliases.index(option_alias)
        q_values = self.agent_model.get_modified_q_values(observation)
        max_q_value = np.abs(np.max(q_values))
        q_values = [np.exp(q_value / max_q_value) for q_value in q_values]
        relevant = q_values[action_num] / np.sum(q_values)
        return relevant


class DQNAgentOverOptions(DQNAgent):
    def __init__(self, model, low_level_policies, policy=None, test_policy=None, enable_double_dqn=True, enable_dueling_network=False, dueling_type='avg', *args, **kwargs):
        super(DQNAgentOverOptions, self).__init__(model, policy, test_policy, enable_double_dqn, enable_dueling_network, dueling_type, *args, **kwargs)
        self.low_level_policies = low_level_policies
        if low_level_policies is not None:
            self.low_level_policy_aliases = list(self.low_level_policies.keys())

    def __get_invalid_node_indices(self):
        """Returns a list of option indices that are invalid according to initiation conditions.
...
@@ -435,5 +466,3 @@ class DQNAgentOverOptions(DQNAgent):
                q_values[node_index] = -np.inf

        return q_values
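get_softq_value_using_option_alias above turns the Q-values into softmax-like weights: each Q-value is divided by the magnitude of the largest one, exponentiated, and the queried option's share of the total is returned. A small self-contained numpy sketch of that computation, with made-up numbers:

    import numpy as np

    q_values = np.array([4.0, 1.0, -2.0])        # illustrative Q-values per option
    action_num = 0                               # index of the queried option

    max_q = np.abs(np.max(q_values))             # scale by |max Q|
    soft = np.exp(q_values / max_q)              # exponentiate the rescaled values
    relevant = soft[action_num] / np.sum(soft)   # chosen option's share of the total
    print(relevant)                              # about 0.59 for these numbers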
backends/learner_base.py  (View file @ 72e44d55)

from .policy_base import PolicyBase
import numpy as np
...
@@ -23,10 +23,10 @@ class LearnerBase(PolicyBase):
            setattr(self, prop, kwargs.get(prop, default))

    def train(self, env, nb_steps=50000, visualize=False, nb_max_episode_steps=200):
        """Train the learning agent on the environment.

        Args:
...
backends/manual_policy.py  (View file @ 72e44d55)

from .controller_base import ControllerBase


class ManualPolicy(ControllerBase):
    """Manual policy execution using nodes and edges."""

    def __init__(self, env, low_level_policies, transition_adj, start_node_alias):
        """Constructor for manual policy execution.

        Args:
...
@@ -13,7 +15,8 @@ class ManualPolicy(ControllerBase):
            start_node: starting node
        """
        super(ManualPolicy, self).__init__(env, low_level_policies, start_node_alias)
        self.adj = transition_adj

    def _transition(self):
...
@@ -50,4 +53,4 @@ class ManualPolicy(ControllerBase):
        new_node = self._transition()
        if new_node is not None:
            self.current_node = new_node
(trailing newline added at end of file)
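ManualPolicy switches options from a hand-written transition_adj structure rather than a learned policy. One plausible shape for it, shown purely as an assumption; the option aliases below are illustrative and are not the maneuvers defined in this repository:

    # Purely illustrative adjacency over option aliases; _transition would
    # consult a structure like this to pick the next option to activate.
    transition_adj = {
        'wait': ['follow'],            # after waiting, start following
        'follow': ['follow', 'stop'],  # keep following or come to a stop
        'stop': ['wait'],
    }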
backends/mcts_learner.py  (View file @ 72e44d55)

@@ -2,6 +2,7 @@ from .controller_base import ControllerBase
import numpy as np
import pickle


class MCTSLearner(ControllerBase):
    """Monte Carlo Tree Search implementation using the UCB1 and
    progressive widening approach as explained in Paxton et al (2017).
...
@@ -9,8 +10,8 @@ class MCTSLearner(ControllerBase):
    _ucb_vals = set()

    def __init__(self, env, low_level_policies, start_node_alias, max_depth=10):
        """Constructor for MCTSLearner.

        Args:
...
@@ -22,10 +23,12 @@ class MCTSLearner(ControllerBase):
            max_depth: max depth of the MCTS tree; default 10 levels
        """
        super(MCTSLearner, self).__init__(env, low_level_policies, start_node_alias)
        self.controller_args_defaults = {
            "predictor": None  #P(s, o) learner class; forward pass should return the entire value from state s and option o
        }
        self.max_depth = max_depth

        #: store current node alias
...
@@ -48,11 +51,17 @@ class MCTSLearner(ControllerBase):
        # populate root node
        root_node_num, root_node_info = self._create_node(self.curr_node_alias)
        self.nodes[root_node_num] = root_node_info
        self.adj[root_node_num] = set()  # no children

    def save_model(self, file_name="mcts.pickle"):
        to_backup = {'N': self.N, 'M': self.M, 'TR': self.TR, 'nodes': self.nodes, 'adj': self.adj, 'new_node_num': self.new_node_num}
        with open(file_name, 'wb') as handle:
            pickle.dump(to_backup, handle, protocol=pickle.HIGHEST_PROTOCOL)
...
@@ -96,9 +105,9 @@ class MCTSLearner(ControllerBase):
        """
        dis_observation = ''
        for item in observation[12:20]:
            if type(item) == bool:
                dis_observation += '1' if item is True else '0'
            if type(item) == int and item in [0, 1]:
                dis_observation += str(item)

        env = self.current_node.env
...
@@ -163,7 +172,8 @@ class MCTSLearner(ControllerBase):
        dis_observation = self._to_discrete(observation)
        if (dis_observation, option) not in self.TR:
            self.TR[(dis_observation, option)] = 0
        return self.TR[(dis_observation, option)] / (1 + self._get_visitation_count(observation, option))

    def _get_possible_options(self):
        """Returns a set of options that can be taken from the current node.
...
@@ -172,7 +182,7 @@ class MCTSLearner(ControllerBase):
        """
        all_options = set(self.low_level_policies.keys())

        # Filter nodes whose initiation condition are true
        filtered_options = set()
        for option_alias in all_options:
...
@@ -211,14 +221,15 @@ class MCTSLearner(ControllerBase):
        Returns Q values for next nodes
        """
        Q = {}
        Q1, Q2 = {}, {}  # debug
        dis_observation = self._to_discrete(observation)
        next_option_nums = self.adj[self.curr_node_num]
        for next_option_num in next_option_nums:
            next_option = self.nodes[next_option_num]["policy"]
            Q1[next_option] = (self._get_q_star(observation, next_option) + 200) / 400
            Q[(dis_observation, next_option)] = \
                Q1[next_option]
            Q2[next_option] = C * \
...
@@ -243,7 +254,7 @@ class MCTSLearner(ControllerBase):
        relevant_rewards = [value for key, value in self.TR.items() \
            if key[0] == dis_observation]
        sum_rewards = np.sum(relevant_rewards)
        return sum_rewards / (1 + self._get_visitation_count(observation))

    def _select(self, observation, depth=0, visualize=False):
        """MCTS selection function. For representation, we only use
...
@@ -266,7 +277,8 @@ class MCTSLearner(ControllerBase):
        if is_terminal or max_depth_reached:
            # print('MCTS went %d nodes deep' % depth)
            return self._value(observation), 0  # TODO: replace with final goal reward

        Ns = self._get_visitation_count(observation)
        Nchildren = len(self.adj[self.curr_node_num])
...
@@ -288,19 +300,20 @@ class MCTSLearner(ControllerBase):
            self.adj[self.curr_node_num].add(new_node_num)

        # Find o_star and do a transition, i.e. update curr_node
        # Simulate / lookup; first change next
        next_observation, episode_R, o_star = self.do_transition(observation, visualize=visualize)

        # Recursively select next node
        remaining_v, all_ep_R = self._select(next_observation, depth + 1, visualize=visualize)

        # Update values
        self.N[dis_observation] += 1
        self.M[(dis_observation, o_star)] += 1
        self.TR[(dis_observation, o_star)] += (episode_R + remaining_v)
        return self._value(observation), all_ep_R + episode_R

    def traverse(self, observation, visualize=False):
        """Do a complete traversal from root to leaf. Assumes the
...
@@ -368,4 +381,4 @@ class MCTSLearner(ControllerBase):
        next_keys, next_values = list(Q1.keys()), list(Q1.values())
        o_star = next_keys[np.argmax(next_values)]
        print(Q1)
        return o_star
(trailing newline added at end of file)
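The class docstring above cites UCB1 with progressive widening (Paxton et al., 2017). In the code, Q1 rescales the return estimate roughly into [0, 1] via (q* + 200) / 400, and Q2 (cut off in this view) adds the exploration bonus. For reference, a generic UCB1 score written as a self-contained sketch that reuses the N / M / TR bookkeeping names; this is not the project's exact formula, and the constant C is an assumption:

    import numpy as np

    def ucb1(total_reward, n_s_o, n_s, C=1.41):
        """Generic UCB1 score for option o in state s.

        total_reward: sum of returns observed for (s, o)  -> like TR[(s, o)]
        n_s_o:        visit count of (s, o)                -> like M[(s, o)]
        n_s:          visit count of s                     -> like N[s]
        """
        exploitation = total_reward / (1 + n_s_o)                      # mean return
        exploration = C * np.sqrt(np.log(1 + n_s) / (1 + n_s_o))       # optimism bonus
        return exploitation + exploration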
backends/online_mcts_controller.py  (View file @ 72e44d55)

@@ -3,6 +3,7 @@ from .mcts_learner import MCTSLearner
import tqdm
import numpy as np


class OnlineMCTSController(ControllerBase):
    """Online MCTS"""
...
@@ -13,12 +14,13 @@ class OnlineMCTSController(ControllerBase):
            env: env instance
            low_level_policies: low level policies dictionary
        """
        super(OnlineMCTSController, self).__init__(env, low_level_policies, start_node_alias)
        self.curr_node_alias = start_node_alias
        self.controller_args_defaults = {
            "predictor": None,
            "max_depth": 5,  # MCTS depth
            "nb_traversals": 30,  # MCTS traversals before decision
        }

    def set_current_node(self, node_alias):
...
@@ -48,12 +50,14 @@ class OnlineMCTSController(ControllerBase):
        env_before_mcts = orig_env.copy()
        self.change_low_level_references(env_before_mcts)
        print('Current Node: %s' % self.curr_node_alias)
        mcts = MCTSLearner(self.env, self.low_level_policies, self.curr_node_alias)
        mcts.max_depth = self.max_depth
        mcts.set_controller_args(predictor=self.predictor)
        # Do nb_traversals number of traversals, reset env to this point every time
        # print('Doing MCTS with params: max_depth = %d, nb_traversals = %d' % (self.max_depth, self.nb_traversals))
        for num_epoch in range(self.nb_traversals):  # tqdm.tqdm(range(self.nb_traversals)):
            mcts.curr_node_num = 0
            env_begin_epoch = env_before_mcts.copy()
            self.change_low_level_references(env_begin_epoch)
...
@@ -63,6 +67,7 @@ class OnlineMCTSController(ControllerBase):
        # Find the nodes from the root node
        mcts.curr_node_num = 0
        print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
        node_after_transition = mcts.get_best_node(self.env.get_features_tuple(), use_ucb=False)
        print('MCTS suggested next option: %s' % node_after_transition)
        self.set_current_node(node_after_transition)
(trailing newline added at end of file)
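The planning loop above snapshots the environment once (env_before_mcts) and restores a fresh copy at the start of every traversal, so all nb_traversals simulations begin from the same real state. A schematic sketch of that snapshot-and-restore pattern; apart from copy(), traverse(), get_best_node(), and curr_node_num, the names here are paraphrased and are not the repository's API:

    def plan_with_mcts(real_env, make_planner, nb_traversals=30):
        # Schematic only: each traversal restarts from the same snapshot.
        snapshot = real_env.copy()
        planner = make_planner(snapshot)
        for _ in range(nb_traversals):
            sim_env = snapshot.copy()      # independent copy per traversal
            planner.env = sim_env          # cf. change_low_level_references(...)
            planner.curr_node_num = 0      # restart the search from the root
            planner.traverse(sim_env.get_features_tuple())
        # Decide greedily (no exploration bonus) from the real state.
        return planner.get_best_node(real_env.get_features_tuple(), use_ucb=False)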
backends/policy_base.py  (View file @ 72e44d55)

class PolicyBase:
    """Abstract policy base from which every policy backend is defined
    and inherited."""
(trailing newline added at end of file)
backends/rl_controller.py  (View file @ 72e44d55)

@@ -11,7 +11,8 @@ class RLController(ControllerBase):
            env: env instance
            low_level_policies: low level policies dictionary
        """
        super(RLController, self).__init__(env, low_level_policies, start_node_alias)
        self.low_level_policy_aliases = list(self.low_level_policies.keys())
        self.trained_policy = None
        self.node_terminal_state_reached = False
...
@@ -32,6 +33,8 @@ class RLController(ControllerBase):
        if self.trained_policy is None:
            raise Exception(self.__class__.__name__ + \
                "trained_policy is not set. Use set_trained_policy().")
        node_index_after_transition = self.trained_policy(self.env.get_features_tuple())
        self.set_current_node(self.low_level_policy_aliases[node_index_after_transition])
        self.node_terminal_state_reached = False
(trailing newline added at end of file)
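The code above defers the choice of the next option to trained_policy, which maps the environment's feature tuple to an option index that is then looked up in low_level_policy_aliases. A hedged usage sketch; the greedy Q-network wrapper and the names q_network and rl_controller are hypothetical stand-ins, and only set_trained_policy itself is referenced in the code above:

    import numpy as np

    # Hypothetical: wrap a trained Q-network so that the controller receives
    # the index of the best-scoring option for the current features.
    def greedy_over_options(features_tuple):
        q = q_network.predict(np.array([features_tuple]))[0]
        return int(np.argmax(q))

    rl_controller.set_trained_policy(greedy_over_options)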
env/env_base.py  (View file @ 72e44d55)

@@ -6,19 +6,22 @@ class GymCompliantEnvBase:
        """ Gym compliant step function which
            will be implemented in the subclass.
        """
        raise NotImplemented(self.__class__.__name__ + "step is not implemented.")

    def reset(self):
        """ Gym compliant reset function which
            will be implemented in the subclass.
        """
        raise NotImplemented(self.__class__.__name__ + "reset is not implemented.")

    def render(self):
        """ Gym compliant step function which
            will be implemented in the subclass.
        """
        raise NotImplemented(self.__class__.__name__ + "render is not implemented.")


class EpisodicEnvBase(GymCompliantEnvBase):
...
@@ -77,13 +80,16 @@ class EpisodicEnvBase(GymCompliantEnvBase):
            return

        if self.terminal_reward_type == 'min':
            self._r_terminal = r_obs if self._r_terminal is None else min(self._r_terminal, r_obs)
        elif self.terminal_reward_type == 'max':
            self._r_terminal = r_obs if self._r_terminal is None else max(self._r_terminal, r_obs)
        elif self.terminal_reward_type == 'sum':
            self._r_terminal = r_obs if self._r_terminal is None else self._r_terminal + r_obs
        else:
            raise AssertionError("The terminal_reward_type has to be 'min', 'max', or 'sum'")

    def step(self, u):

        # the penalty is a negative reward.
...
@@ -118,7 +124,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
            if LTL_precondition.enabled:
                LTL_precondition.check_incremental(self.__mc_AP)
                if LTL_precondition.result == Parser.FALSE:
                    self._terminal_reward_superposition(EpisodicEnvBase._reward(LTL_precondition.penalty))
                    violate = True
                    info['ltl_violation'] = LTL_precondition.str
                    # print("\nViolation of \"" + LTL_precondition.str + "\"")
...
@@ -177,7 +184,6 @@ class EpisodicEnvBase(GymCompliantEnvBase):
                return True

        return False

    # TODO: replace these confusing methods reward and penalty, or not to use both reward and penalty for the property naming.
    @staticmethod
    def _reward(penalty):
...
@@ -186,4 +192,3 @@ class EpisodicEnvBase(GymCompliantEnvBase):
    @staticmethod
    def _penalty(reward):
        return None if reward is None else -reward
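The terminal-reward logic above folds successive terminal rewards into a single value according to terminal_reward_type. The same folding rule as a small standalone sketch, outside the class, with a worked example:

    def superpose(current, r_obs, mode='min'):
        # Mirrors the min/max/sum folding shown in the hunk above.
        if r_obs is None:
            return current
        if current is None:
            return r_obs
        if mode == 'min':
            return min(current, r_obs)
        if mode == 'max':
            return max(current, r_obs)
        if mode == 'sum':
            return current + r_obs
        raise AssertionError("The terminal_reward_type has to be 'min', 'max', or 'sum'")

    # e.g. two terminal penalties of -100 and -150:
    assert superpose(-100, -150, 'min') == -150
    assert superpose(-100, -150, 'max') == -100
    assert superpose(-100, -150, 'sum') == -250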
env/simple_intersection/__init__.py  (View file @ 72e44d55)

from .simple_intersection_env import SimpleIntersectionEnv
(trailing newline added at end of file)
env/simple_intersection/features.py  (View file @ 72e44d55)

@@ -34,7 +34,7 @@ con_ego_feature_dict = {
    'pos_stop_region': 11
}

#: dis_ego_feature_dict contains all indexing information regarding
#  each element of the ego vehicle's discrete feature vector.
#
#  * not_in_stop_region: True if the ego is in stop region;
...
@@ -69,7 +69,8 @@ ego_feature_dict = dict()
for key in con_ego_feature_dict.keys():
    ego_feature_dict[key] = con_ego_feature_dict[key]
for key in dis_ego_feature_dict.keys():
    ego_feature_dict[key] = dis_ego_feature_dict[key] + len(con_ego_feature_dict)

ego_feature_len = len(ego_feature_dict)
other_veh_feature_len = len(other_veh_feature_dict)
...
@@ -81,8 +82,8 @@ def extract_ego_features(features_tuple, *args):
def extract_other_veh_features(features_tuple, veh_index, *args):
    return tuple(features_tuple[ego_feature_len + (veh_index - 1) * other_veh_feature_len + other_veh_feature_dict[key]] for key in args)
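The indexing in extract_other_veh_features works as follows: the flat features_tuple begins with the ego block (continuous entries first, then the discrete entries offset by len(con_ego_feature_dict)), followed by one fixed-size block per other vehicle. A worked sketch of that arithmetic with made-up sizes and keys; the real dictionaries are defined earlier in this file:

    # Illustrative sizes and keys only, not the project's actual feature layout.
    ego_feature_len = 20          # ego block: continuous + discrete entries
    other_veh_feature_len = 5     # per-vehicle block size
    other_veh_feature_dict = {'x': 0, 'y': 1, 'v': 2, 'theta': 3, 'waited': 4}

    def other_veh_index(veh_index, key):
        # veh_index is 1-based: vehicle 1 starts right after the ego block.
        return (ego_feature_len + (veh_index - 1) * other_veh_feature_len
                + other_veh_feature_dict[key])

    print(other_veh_index(1, 'x'))   # 20  -> first slot after the ego block
    print(other_veh_index(3, 'v'))   # 20 + 2*5 + 2 = 32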