wise-lab / wise-move · Commits

Commit f2171d2c, authored Nov 19, 2018 by Aravind Balakrishnan

Merge branch 'formatting' into 'master'

Formatting

See merge request !3

Parents: e1fdb162 7abc600f
Changes: 37 changed files with 1933 additions and 1487 deletions (+1933 / -1487)
.gitignore (+3 / -0)
backends/baselines_learner.py (+9 / -6)
backends/controller_base.py (+14 / -8)
backends/kerasrl_learner.py (+74 / -45)
backends/learner_base.py (+5 / -5)
backends/manual_policy.py (+8 / -5)
backends/mcts_learner.py (+61 / -49)
backends/online_mcts_controller.py (+14 / -9)
backends/policy_base.py (+2 / -2)
backends/rl_controller.py (+7 / -4)
env/env_base.py (+28 / -26)
env/road_env.py (+3 / -3)
env/simple_intersection/__init__.py (+1 / -1)
env/simple_intersection/features.py (+24 / -18)
env/simple_intersection/road_geokinemetry.py (+10 / -8)
env/simple_intersection/road_networks.py (+20 / -10)
env/simple_intersection/shapes.py (+5 / -3)
env/simple_intersection/simple_intersection_env.py (+173 / -106)
env/simple_intersection/utilities.py (+17 / -11)
env/simple_intersection/vehicle_networks.py (+7 / -3)
env/simple_intersection/vehicles.py (+31 / -18)
high_level_policy_main.py (+114 / -52)
low_level_policy_main.py (+86 / -40)
mcts.py (+122 / -59)
model_checker/LTL_property_base.py (+17 / -15)
model_checker/atomic_propositions_base.py (+6 / -6)
model_checker/parser.py (+480 / -475)
model_checker/scanner.py (+303 / -308)
model_checker/simple_intersection/AP_dict.py (+0 / -1)
model_checker/simple_intersection/LTL_test.py (+28 / -14)
model_checker/simple_intersection/__init__.py (+1 / -1)
model_checker/simple_intersection/classes.py (+6 / -8)
options/options_loader.py (+43 / -31)
options/simple_intersection/maneuver_base.py (+57 / -33)
options/simple_intersection/maneuvers.py (+78 / -59)
options/simple_intersection/mcts_maneuvers.py (+64 / -39)
ppo2_training.py (+12 / -6)
.gitignore (new file, mode 100644)

*.pyc
*.py~
__pycache__
backends/baselines_learner.py

@@ -35,7 +35,9 @@ class PPO2Agent(LearnerBase):
        self.log_path = log_path
        self.env = DummyVecEnv([lambda: env])  #PPO2 requried a vectorized environment for parallel training
        self.agent_model = self.create_agent(policy, tensorboard)

    def get_default_policy(self):

@@ -46,13 +48,13 @@ class PPO2Agent(LearnerBase):
        return MlpPolicy

    def create_agent(self, policy, tensorboard):
        """Creates a PPO agent.

        Returns:
            stable_baselines PPO2 object
        """
        if tensorboard:
            return PPO2(policy, self.env, verbose=1, tensorboard_log=self.log_path)
        else:
            return PPO2(policy, self.env, verbose=1)

@@ -100,7 +102,8 @@ class PPO2Agent(LearnerBase):
            episode_rewards[-1] += rewards[0]
            if dones[0] or current_step > nb_max_episode_steps:
                obs = self.env.reset()
                print("Episode ", current_episode, "reward: ", episode_rewards[-1])
                episode_rewards.append(0.0)
                current_episode += 1
                current_step = 0
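For context on the PPO2Agent hunks above: they wrap the usual stable-baselines workflow of vectorizing a single Gym environment with DummyVecEnv and handing it to PPO2. A minimal standalone sketch of that pattern (the CartPole environment and step counts are placeholders, not values from this diff) looks like:

import gym
from stable_baselines import PPO2
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv

env = gym.make("CartPole-v1")            # placeholder environment
vec_env = DummyVecEnv([lambda: env])     # PPO2 expects a vectorized environment
model = PPO2(MlpPolicy, vec_env, verbose=1, tensorboard_log="./logs")
model.learn(total_timesteps=10000)       # placeholder training budget
obs = vec_env.reset()
action, _ = model.predict(obs)           # action for the current observation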
backends/controller_base.py

@@ -18,16 +18,17 @@ class ControllerBase(PolicyBase):
            setattr(self, prop, kwargs.get(prop, default))

    def can_transition(self):
        """Returns boolean signifying whether we can transition.

        To be implemented in subclass.
        """
        raise NotImplemented(self.__class__.__name__ + \
            "can_transition is not implemented.")

    def do_transition(self, observation):
        """Do a transition, assuming we can transition. To be implemented in
        subclass.

        Args:
            observation: final observation from episodic step

@@ -37,7 +38,7 @@ class ControllerBase(PolicyBase):
            "do_transition is not implemented.")

    def set_current_node(self, node_alias):
        """Sets the current node which is being executed.

        Args:
            node: node alias of the node to be set

@@ -50,11 +51,13 @@ class ControllerBase(PolicyBase):
        Returns state at end of node execution, total reward, epsiode_termination_flag, info
        '''

    def step_current_node(self, visualize_low_level_steps=False):
        total_reward = 0
        self.node_terminal_state_reached = False
        while not self.node_terminal_state_reached:
            observation, reward, terminal, info = self.low_level_step_current_node()
            if visualize_low_level_steps:
                self.env.render()
            total_reward += reward

@@ -70,9 +73,12 @@ class ControllerBase(PolicyBase):
        Returns state after one step, step reward, episode_termination_flag, info
        '''

    def low_level_step_current_node(self):
        u_ego = self.current_node.low_level_policy(
            self.current_node.get_reduced_features_tuple())
        feature, R, terminal, info = self.current_node.step(u_ego)
        self.node_terminal_state_reached = terminal
        return self.env.get_features_tuple(), R, self.env.termination_condition, info
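The two unimplemented hooks above define the contract a concrete controller must fill in. A hypothetical minimal subclass, shown only to illustrate the interface (the class name and behaviour are made up and not part of this repository), could be:

class SingleOptionController(ControllerBase):
    """Toy controller that runs one option forever and never transitions."""

    def can_transition(self):
        # This toy controller never hands control to another option.
        return False

    def do_transition(self, observation):
        # Unreachable while can_transition() always returns False.
        raise RuntimeError("SingleOptionController never transitions.")

With those hooks in place, the inherited step_current_node() loop runs the current option's low-level policy until its terminal condition is reached, accumulating reward along the way.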
backends/kerasrl_learner.py

@@ -47,7 +47,8 @@ class DDPGLearner(LearnerBase):
            "oup_mu": 0,  # OrnsteinUhlenbeckProcess mu
            "oup_sigma": 1,  # OrnsteinUhlenbeckProcess sigma
            "oup_sigma_min": 0.5,  # OrnsteinUhlenbeckProcess sigma min
            "oup_annealing_steps": 500000,  # OrnsteinUhlenbeckProcess n-step annealing
            "nb_steps_warmup_critic": 100,  # steps for critic to warmup
            "nb_steps_warmup_actor": 100,  # steps for actor to warmup
            "target_model_update": 1e-3  # target model update frequency

@@ -160,24 +161,33 @@ class DDPGLearner(LearnerBase):
            target_model_update=1e-3)

        # TODO: give params like lr_actor and lr_critic to set different lr of Actor and Critic.
        agent.compile(
            [Adam(lr=self.lr * 1e-2, clipnorm=1.), Adam(lr=self.lr, clipnorm=1.)],
            metrics=['mae'])

        return agent

    def train(self,
              env,
              nb_steps=1000000,
              visualize=False,
              verbose=1,
              log_interval=10000,
              nb_max_episode_steps=200,
              model_checkpoints=False,
              checkpoint_interval=100000,
              tensorboard=False):

        callbacks = []
        if model_checkpoints:
            callbacks += [
                ModelIntervalCheckpoint(
                    './checkpoints/checkpoint_weights.h5f',
                    interval=checkpoint_interval)
            ]
        if tensorboard:
            callbacks += [TensorBoard(log_dir='./logs')]
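The oup_* defaults above parameterize Ornstein-Uhlenbeck exploration noise. A sketch of how such defaults are typically passed to keras-rl's OrnsteinUhlenbeckProcess (the theta value and action dimension below are assumptions, not values shown in this diff):

from rl.random import OrnsteinUhlenbeckProcess

nb_actions = 2  # assumed action dimension
random_process = OrnsteinUhlenbeckProcess(
    theta=0.15,                # assumed; not shown in the hunk above
    mu=0,                      # oup_mu
    sigma=1,                   # oup_sigma
    sigma_min=0.5,             # oup_sigma_min
    n_steps_annealing=500000,  # oup_annealing_steps
    size=nb_actions)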
@@ -291,28 +301,36 @@ class DQNLearner(LearnerBase):
        Returns:
            KerasRL DQN object
        """
        agent = DQNAgentOverOptions(
            model=model,
            low_level_policies=self.low_level_policies,
            nb_actions=self.nb_actions,
            memory=memory,
            nb_steps_warmup=self.nb_steps_warmup,
            target_model_update=self.target_model_update,
            policy=policy,
            enable_dueling_network=True)

        agent.compile(Adam(lr=self.lr), metrics=['mae'])

        return agent

    def train(self,
              env,
              nb_steps=1000000,
              visualize=False,
              nb_max_episode_steps=200,
              tensorboard=False,
              model_checkpoints=False,
              checkpoint_interval=10000):

        callbacks = []
        if model_checkpoints:
            callbacks += [
                ModelIntervalCheckpoint(
                    './checkpoints/checkpoint_weights.h5f',
                    interval=checkpoint_interval)
            ]
        if tensorboard:
            callbacks += [TensorBoard(log_dir='./logs')]

@@ -333,7 +351,7 @@ class DQNLearner(LearnerBase):
             nb_episodes=5,
             visualize=True,
             nb_max_episode_steps=400,
             success_reward_threshold=100):

        print("Testing for {} episodes".format(nb_episodes))
        success_count = 0

@@ -359,13 +377,14 @@ class DQNLearner(LearnerBase):
                env.reset()
                if episode_reward >= success_reward_threshold:
                    success_count += 1
                print("Episode {}: steps:{}, reward:{}".format(
                    n + 1, step, episode_reward))

        print("\nPolicy succeeded {} times!".format(success_count))
        print("Failures due to:")
        print(termination_reason_counter)

        return [success_count, termination_reason_counter]

    def load_model(self, file_name="test_weights.h5f"):
        self.agent_model.load_weights(file_name)

@@ -377,31 +396,43 @@ class DQNLearner(LearnerBase):
        return self.agent_model.get_modified_q_values(observation)[action]

    def get_q_value_using_option_alias(self, observation, option_alias):
        action_num = self.agent_model.low_level_policy_aliases.index(
            option_alias)
        return self.agent_model.get_modified_q_values(observation)[action_num]

    def get_softq_value_using_option_alias(self, observation, option_alias):
        action_num = self.agent_model.low_level_policy_aliases.index(
            option_alias)
        q_values = self.agent_model.get_modified_q_values(observation)
        max_q_value = np.abs(np.max(q_values))
        q_values = [np.exp(q_value / max_q_value) for q_value in q_values]
        relevant = q_values[action_num] / np.sum(q_values)
        return relevant


class DQNAgentOverOptions(DQNAgent):
    def __init__(self,
                 model,
                 low_level_policies,
                 policy=None,
                 test_policy=None,
                 enable_double_dqn=True,
                 enable_dueling_network=False,
                 dueling_type='avg',
                 *args,
                 **kwargs):
        super(DQNAgentOverOptions, self).__init__(
            model, policy, test_policy, enable_double_dqn,
            enable_dueling_network, dueling_type, *args, **kwargs)

        self.low_level_policies = low_level_policies
        if low_level_policies is not None:
            self.low_level_policy_aliases = list(
                self.low_level_policies.keys())

    def __get_invalid_node_indices(self):
        """Returns a list of option indices that are invalid according to
        initiation conditions.
        """
        invalid_node_indices = list()
        for index, option_alias in enumerate(self.low_level_policy_aliases):
            self.low_level_policies[option_alias].reset_maneuver()

@@ -435,5 +466,3 @@ class DQNAgentOverOptions(DQNAgent):
                q_values[node_index] = -np.inf

        return q_values
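The soft-Q helper above is effectively a softmax over the option Q-values with temperature |max(Q)|. A standalone sketch of the same computation (the function name and sample values are illustrative only):

import numpy as np

def soft_q_probability(q_values, action_num):
    """Probability mass assigned to one option under a softmax whose
    temperature is the magnitude of the largest Q-value."""
    q_values = np.asarray(q_values, dtype=float)
    max_q_value = np.abs(np.max(q_values))   # assumes the largest |Q| is non-zero
    exp_q = np.exp(q_values / max_q_value)
    return exp_q[action_num] / np.sum(exp_q)

print(soft_q_probability([1.0, 2.0, 4.0], 2))  # the largest Q-value gets the largest share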
backends/learner_base.py

from .policy_base import PolicyBase
import numpy as np

@@ -23,10 +23,10 @@ class LearnerBase(PolicyBase):
            setattr(self, prop, kwargs.get(prop, default))

    def train(self,
              env,
              nb_steps=50000,
              visualize=False,
              nb_max_episode_steps=200):
        """Train the learning agent on the environment.

        Args:
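The setattr line above is the keyword-override idiom used by both LearnerBase and ControllerBase: each entry in a defaults dict becomes an attribute unless the caller passes its own value. A self-contained sketch (the dict contents are placeholders, not the repository's actual defaults):

class DefaultsExample:
    def __init__(self, **kwargs):
        defaults = {"lr": 1e-3, "gamma": 0.99}  # placeholder hyperparameters
        for prop, default in defaults.items():
            setattr(self, prop, kwargs.get(prop, default))

agent = DefaultsExample(gamma=0.9)
print(agent.lr, agent.gamma)  # -> 0.001 0.9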
backends/manual_policy.py

from .controller_base import ControllerBase


class ManualPolicy(ControllerBase):
    """Manual policy execution using nodes and edges."""

    def __init__(self, env, low_level_policies, transition_adj,
                 start_node_alias):
        """Constructor for manual policy execution.

        Args:

@@ -13,12 +15,13 @@ class ManualPolicy(ControllerBase):
            start_node: starting node
        """
        super(ManualPolicy, self).__init__(env, low_level_policies,
                                           start_node_alias)
        self.adj = transition_adj

    def _transition(self):
        """Check if the current node's termination condition is met and if it
        is possible to transition to another node, i.e. its initiation
        condition is met. This is an internal function.

        Returns the new node if a transition can happen, None otherwise

@@ -50,4 +53,4 @@ class ManualPolicy(ControllerBase):
        new_node = self._transition()
        if new_node is not None:
            self.current_node = new_node
\ No newline at end of file
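The visible hunks do not show the shape of transition_adj. One plausible encoding, shown purely for illustration, is a mapping from each node alias to the aliases it may hand over to; the maneuver names below are made up, and the env and low_level_policies objects are assumed to be built elsewhere:

# Hypothetical adjacency between maneuver aliases.
transition_adj = {
    "keeplane": ["follow", "changelane"],
    "follow": ["keeplane"],
    "changelane": ["keeplane"],
}

# policy = ManualPolicy(env, low_level_policies, transition_adj,
#                       start_node_alias="keeplane")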
backends/mcts_learner.py

@@ -2,15 +2,15 @@ from .controller_base import ControllerBase
import numpy as np
import pickle


class MCTSLearner(ControllerBase):
    """Monte Carlo Tree Search implementation using the UCB1 and progressive
    widening approach as explained in Paxton et al (2017)."""

    _ucb_vals = set()

    def __init__(self, env, low_level_policies, start_node_alias,
                 max_depth=10):
        """Constructor for MCTSLearner.

        Args:

@@ -22,10 +22,12 @@ class MCTSLearner(ControllerBase):
            max_depth: max depth of the MCTS tree; default 10 levels
        """
        super(MCTSLearner, self).__init__(env, low_level_policies,
                                          start_node_alias)
        self.controller_args_defaults = {
            "predictor": None  #P(s, o) learner class; forward pass should return the entire value from state s and option o
        }
        self.max_depth = max_depth

        #: store current node alias

@@ -48,11 +50,17 @@ class MCTSLearner(ControllerBase):
        # populate root node
        root_node_num, root_node_info = self._create_node(self.curr_node_alias)
        self.nodes[root_node_num] = root_node_info
        self.adj[root_node_num] = set()  # no children

    def save_model(self, file_name="mcts.pickle"):
        to_backup = {
            'N': self.N,
            'M': self.M,
            'TR': self.TR,
            'nodes': self.nodes,
            'adj': self.adj,
            'new_node_num': self.new_node_num
        }
        with open(file_name, 'wb') as handle:
            pickle.dump(to_backup, handle, protocol=pickle.HIGHEST_PROTOCOL)

@@ -67,8 +75,8 @@ class MCTSLearner(ControllerBase):
        self.new_node_num = to_restore['new_node_num']

    def _create_node(self, low_level_policy):
        """Create the node associated with curr_node_num, using the given low
        level policy.

        Args:
            low_level_policy: the option's alias

@@ -83,11 +91,10 @@ class MCTSLearner(ControllerBase):
        return created_node_num, {"policy": low_level_policy}

    def _to_discrete(self, observation):
        """Converts observation to a discrete observation tuple. Also append
        (a) whether we are following a vehicle, and (b) whether there is a
        vehicle in the opposite lane in the approximately the same x position.
        These values will be useful for Follow and ChangeLane maneuvers.

        Args:
            observation: observation tuple from the environment

@@ -96,9 +103,9 @@ class MCTSLearner(ControllerBase):
        """
        dis_observation = ''
        for item in observation[12:20]:
            if type(item) == bool:
                dis_observation += '1' if item is True else '0'

            if type(item) == int and item in [0, 1]:
                dis_observation += str(item)

        env = self.current_node.env

@@ -128,9 +135,9 @@ class MCTSLearner(ControllerBase):
    def _get_visitation_count(self, observation, option=None):
        """Finds the visitation count of the discrete form of the observation.
        If discrete observation not found, then inserted into self.N with
        value 0. Auto converts the observation into discrete form. If option
        is not None, then this uses self.M instead of self.N.

        Args:
            observation: observation tuple from the environment

@@ -163,16 +170,18 @@ class MCTSLearner(ControllerBase):
        dis_observation = self._to_discrete(observation)
        if (dis_observation, option) not in self.TR:
            self.TR[(dis_observation, option)] = 0
        return self.TR[(dis_observation, option)] / (
            1 + self._get_visitation_count(observation, option))

    def _get_possible_options(self):
        """Returns a set of options that can be taken from the current node.
        Goes through adjacency set of current node and finds which next nodes'
        initiation condition is met.
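The running average above (TR divided by one plus the visitation count) is the exploitation term that UCB1-style selection combines with an exploration bonus. A sketch of the standard UCB1 score (the constant c and the count arguments are illustrative and not this class's exact bookkeeping):

import numpy as np

def ucb1_score(avg_reward, state_visits, state_option_visits, c=1.41):
    """Standard UCB1: average reward plus an exploration bonus that shrinks
    as the (state, option) pair is visited more often."""
    exploration = c * np.sqrt(np.log(1 + state_visits) / (1 + state_option_visits))
    return avg_reward + exploration

print(ucb1_score(avg_reward=0.5, state_visits=20, state_option_visits=3))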