wise-lab / wise-move

Commit f2171d2c, authored Nov 19, 2018 by Aravind Balakrishnan
Parents: e1fdb162, 7abc600f

    Merge branch 'formatting' into 'master'

    Formatting

    See merge request !3

Showing 37 changed files with 1933 additions and 1487 deletions.
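
To reproduce this summary from a local clone, one option is to query git for the merge commit's metadata and for its diff against the first parent. A minimal sketch (assuming the clone already contains commit f2171d2c, and that e1fdb162 is the first parent, as listed above):

import subprocess

# Print the merge commit's author, date, and message without the patch body.
subprocess.run(["git", "show", "--no-patch", "f2171d2c"], check=True)

# Reproduce the per-file +/- counts of the merge relative to its first parent.
subprocess.run(["git", "diff", "--stat", "e1fdb162", "f2171d2c"], check=True)
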
.gitignore                                            +3    -0
backends/baselines_learner.py                         +9    -6
backends/controller_base.py                           +14   -8
backends/kerasrl_learner.py                           +74   -45
backends/learner_base.py                              +5    -5
backends/manual_policy.py                             +8    -5
backends/mcts_learner.py                              +61   -49
backends/online_mcts_controller.py                    +14   -9
backends/policy_base.py                               +2    -2
backends/rl_controller.py                             +7    -4
env/env_base.py                                       +28   -26
env/road_env.py                                       +3    -3
env/simple_intersection/__init__.py                   +1    -1
env/simple_intersection/features.py                   +24   -18
env/simple_intersection/road_geokinemetry.py          +10   -8
env/simple_intersection/road_networks.py              +20   -10
env/simple_intersection/shapes.py                     +5    -3
env/simple_intersection/simple_intersection_env.py    +173  -106
env/simple_intersection/utilities.py                  +17   -11
env/simple_intersection/vehicle_networks.py           +7    -3
env/simple_intersection/vehicles.py                   +31   -18
high_level_policy_main.py                             +114  -52
low_level_policy_main.py                              +86   -40
mcts.py                                               +122  -59
model_checker/LTL_property_base.py                    +17   -15
model_checker/atomic_propositions_base.py             +6    -6
model_checker/parser.py                               +480  -475
model_checker/scanner.py                              +303  -308
model_checker/simple_intersection/AP_dict.py          +0    -1
model_checker/simple_intersection/LTL_test.py         +28   -14
model_checker/simple_intersection/__init__.py         +1    -1
model_checker/simple_intersection/classes.py          +6    -8
options/options_loader.py                             +43   -31
options/simple_intersection/maneuver_base.py          +57   -33
options/simple_intersection/maneuvers.py              +78   -59
options/simple_intersection/mcts_maneuvers.py         +64   -39
ppo2_training.py                                      +12   -6

.gitignore  (new file, mode 100644)

*.pyc
*.py~
__pycache__

backends/baselines_learner.py

@@ -35,7 +35,9 @@ class PPO2Agent(LearnerBase):
        self.log_path = log_path
        self.env = DummyVecEnv([
            lambda: env
        ])  #PPO2 requried a vectorized environment for parallel training
        self.agent_model = self.create_agent(policy, tensorboard)

    def get_default_policy(self):

@@ -46,13 +48,13 @@ class PPO2Agent(LearnerBase):
        return MlpPolicy

    def create_agent(self, policy, tensorboard):
        """Creates a PPO agent.

        Returns: stable_baselines PPO2 object
        """
        if tensorboard:
            return PPO2(
                policy, self.env, verbose=1, tensorboard_log=self.log_path)
        else:
            return PPO2(policy, self.env, verbose=1)

@@ -100,7 +102,8 @@ class PPO2Agent(LearnerBase):
            episode_rewards[-1] += rewards[0]
            if dones[0] or current_step > nb_max_episode_steps:
                obs = self.env.reset()
                print("Episode ", current_episode, "reward: ",
                      episode_rewards[-1])
                episode_rewards.append(0.0)
                current_episode += 1
                current_step = 0

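The edits in this file are re-wrapping of long calls and docstring punctuation, which is the kind of change an autoformatter produces. A minimal sketch of applying such a formatter over the project tree, assuming yapf is installed and on the PATH (the choice of yapf and of the target directories is an assumption, not something stated by the commit):

import subprocess

# Reformat Python sources in place, recursing into the listed packages.
subprocess.run(
    ["yapf", "--in-place", "--recursive",
     "backends", "env", "options", "model_checker"],
    check=True)
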
backends/controller_base.py

@@ -18,16 +18,17 @@ class ControllerBase(PolicyBase):
            setattr(self, prop, kwargs.get(prop, default))

    def can_transition(self):
        """Returns boolean signifying whether we can transition.

        To be implemented in subclass.
        """
        raise NotImplemented(self.__class__.__name__ + \
            "can_transition is not implemented.")

    def do_transition(self, observation):
        """Do a transition, assuming we can transition. To be implemented in
        subclass.

        Args:
            observation: final observation from episodic step

@@ -37,7 +38,7 @@ class ControllerBase(PolicyBase):
            "do_transition is not implemented.")

    def set_current_node(self, node_alias):
        """Sets the current node which is being executed.

        Args:
            node: node alias of the node to be set

@@ -50,11 +51,13 @@ class ControllerBase(PolicyBase):
    Returns state at end of node execution, total reward, epsiode_termination_flag, info
    '''

    def step_current_node(self, visualize_low_level_steps=False):
        total_reward = 0
        self.node_terminal_state_reached = False
        while not self.node_terminal_state_reached:
            observation, reward, terminal, info = self.low_level_step_current_node(
            )
            if visualize_low_level_steps:
                self.env.render()
            total_reward += reward

@@ -70,9 +73,12 @@ class ControllerBase(PolicyBase):
    Returns state after one step, step reward, episode_termination_flag, info
    '''

    def low_level_step_current_node(self):
        u_ego = self.current_node.low_level_policy(
            self.current_node.get_reduced_features_tuple())
        feature, R, terminal, info = self.current_node.step(u_ego)
        self.node_terminal_state_reached = terminal
        return self.env.get_features_tuple(
        ), R, self.env.termination_condition, info

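The hunks above define the per-option execution API: step_current_node runs the current option's low-level policy until its terminal condition and returns (observation, total reward, termination flag, info), while can_transition and do_transition are left to subclasses. A hypothetical driver loop over that interface, using a mock controller so the sketch is self-contained (only the method names and return shape come from the code above; everything else is illustrative):

# Illustrative only: a mock with the same step/transition interface as ControllerBase.
class MockController:
    def __init__(self):
        self.steps = 0

    def step_current_node(self, visualize_low_level_steps=False):
        self.steps += 1
        # observation, accumulated reward, episode termination flag, info
        return [0.0], 1.0, self.steps >= 3, {}

    def can_transition(self):
        return True

    def do_transition(self, observation):
        pass  # a real subclass would switch the current node here


controller = MockController()
episode_done = False
episode_return = 0
while not episode_done:
    obs, reward, episode_done, _ = controller.step_current_node()
    episode_return += reward
    if not episode_done and controller.can_transition():
        controller.do_transition(obs)
print("episode return:", episode_return)
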
backends/kerasrl_learner.py

@@ -47,7 +47,8 @@ class DDPGLearner(LearnerBase):
            "oup_mu": 0,  # OrnsteinUhlenbeckProcess mu
            "oup_sigma": 1,  # OrnsteinUhlenbeckProcess sigma
            "oup_sigma_min": 0.5,  # OrnsteinUhlenbeckProcess sigma min
            "oup_annealing_steps":
            500000,  # OrnsteinUhlenbeckProcess n-step annealing
            "nb_steps_warmup_critic": 100,  # steps for critic to warmup
            "nb_steps_warmup_actor": 100,  # steps for actor to warmup
            "target_model_update": 1e-3  # target model update frequency

@@ -160,24 +161,33 @@ class DDPGLearner(LearnerBase):
            target_model_update=1e-3)

        # TODO: give params like lr_actor and lr_critic to set different lr of Actor and Critic.
        agent.compile(
            [
                Adam(lr=self.lr * 1e-2, clipnorm=1.),
                Adam(lr=self.lr, clipnorm=1.)
            ],
            metrics=['mae'])

        return agent

    def train(self,
              env,
              nb_steps=1000000,
              visualize=False,
              verbose=1,
              log_interval=10000,
              nb_max_episode_steps=200,
              model_checkpoints=False,
              checkpoint_interval=100000,
              tensorboard=False):

        callbacks = []
        if model_checkpoints:
            callbacks += [
                ModelIntervalCheckpoint(
                    './checkpoints/checkpoint_weights.h5f',
                    interval=checkpoint_interval)
            ]
        if tensorboard:
            callbacks += [TensorBoard(log_dir='./logs')]

@@ -291,28 +301,36 @@ class DQNLearner(LearnerBase):
        Returns:
            KerasRL DQN object
        """
        agent = DQNAgentOverOptions(
            model=model,
            low_level_policies=self.low_level_policies,
            nb_actions=self.nb_actions,
            memory=memory,
            nb_steps_warmup=self.nb_steps_warmup,
            target_model_update=self.target_model_update,
            policy=policy,
            enable_dueling_network=True)

        agent.compile(Adam(lr=self.lr), metrics=['mae'])

        return agent

    def train(self,
              env,
              nb_steps=1000000,
              visualize=False,
              nb_max_episode_steps=200,
              tensorboard=False,
              model_checkpoints=False,
              checkpoint_interval=10000):

        callbacks = []
        if model_checkpoints:
            callbacks += [
                ModelIntervalCheckpoint(
                    './checkpoints/checkpoint_weights.h5f',
                    interval=checkpoint_interval)
            ]
        if tensorboard:
            callbacks += [TensorBoard(log_dir='./logs')]

@@ -333,7 +351,7 @@ class DQNLearner(LearnerBase):
             nb_episodes=5,
             visualize=True,
             nb_max_episode_steps=400,
             success_reward_threshold=100):

        print("Testing for {} episodes".format(nb_episodes))
        success_count = 0

@@ -359,13 +377,14 @@ class DQNLearner(LearnerBase):
                env.reset()
                if episode_reward >= success_reward_threshold:
                    success_count += 1
                print("Episode {}: steps:{}, reward:{}".format(
                    n + 1, step, episode_reward))

        print("\nPolicy succeeded {} times!".format(success_count))
        print("Failures due to:")
        print(termination_reason_counter)

        return [success_count, termination_reason_counter]

    def load_model(self, file_name="test_weights.h5f"):
        self.agent_model.load_weights(file_name)

@@ -377,31 +396,43 @@ class DQNLearner(LearnerBase):
        return self.agent_model.get_modified_q_values(observation)[action]

    def get_q_value_using_option_alias(self, observation, option_alias):
        action_num = self.agent_model.low_level_policy_aliases.index(
            option_alias)
        return self.agent_model.get_modified_q_values(observation)[action_num]

    def get_softq_value_using_option_alias(self, observation, option_alias):
        action_num = self.agent_model.low_level_policy_aliases.index(
            option_alias)
        q_values = self.agent_model.get_modified_q_values(observation)
        max_q_value = np.abs(np.max(q_values))
        q_values = [
            np.exp(q_value / max_q_value) for q_value in q_values
        ]
        relevant = q_values[action_num] / np.sum(q_values)
        return relevant


class DQNAgentOverOptions(DQNAgent):
    def __init__(self,
                 model,
                 low_level_policies,
                 policy=None,
                 test_policy=None,
                 enable_double_dqn=True,
                 enable_dueling_network=False,
                 dueling_type='avg',
                 *args,
                 **kwargs):
        super(DQNAgentOverOptions, self).__init__(
            model, policy, test_policy, enable_double_dqn,
            enable_dueling_network, dueling_type, *args, **kwargs)

        self.low_level_policies = low_level_policies
        if low_level_policies is not None:
            self.low_level_policy_aliases = list(
                self.low_level_policies.keys())

    def __get_invalid_node_indices(self):
        """Returns a list of option indices that are invalid according to
        initiation conditions.
        """
        invalid_node_indices = list()
        for index, option_alias in enumerate(self.low_level_policy_aliases):
            self.low_level_policies[option_alias].reset_maneuver()

@@ -435,5 +466,3 @@ class DQNAgentOverOptions(DQNAgent):
            q_values[node_index] = -np.inf

        return q_values

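The get_softq_value_using_option_alias hunk above normalizes the option Q-values into a softmax-like weight before picking out one option's share. A standalone sketch of that normalization with made-up numbers (the Q-values and option index below are illustrative, not from the repository):

import numpy as np

q_values = np.array([12.0, -3.5, 7.2])        # dummy Q-values, one per option
action_num = 0                                # index of the option of interest

max_q_value = np.abs(np.max(q_values))        # scale factor, as in the hunk above
exp_q = np.exp(q_values / max_q_value)        # exponentiate the scaled Q-values
relevant = exp_q[action_num] / np.sum(exp_q)  # normalized weight in (0, 1)
print(relevant)
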
backends/learner_base.py

from .policy_base import PolicyBase
import numpy as np

@@ -23,10 +23,10 @@ class LearnerBase(PolicyBase):
            setattr(self, prop, kwargs.get(prop, default))

    def train(self,
              env,
              nb_steps=50000,
              visualize=False,
              nb_max_episode_steps=200):
        """Train the learning agent on the environment.

        Args:

backends/manual_policy.py

from .controller_base import ControllerBase


class ManualPolicy(ControllerBase):
    """Manual policy execution using nodes and edges."""

    def __init__(self, env, low_level_policies, transition_adj,
                 start_node_alias):
        """Constructor for manual policy execution.

        Args:

@@ -13,12 +15,13 @@ class ManualPolicy(ControllerBase):
            start_node: starting node
        """
        super(ManualPolicy, self).__init__(env, low_level_policies,
                                           start_node_alias)
        self.adj = transition_adj

    def _transition(self):
        """Check if the current node's termination condition is met and if
        it is possible to transition to another node, i.e. its initiation
        condition is met. This is an internal function.

        Returns the new node if a transition can happen, None otherwise

@@ -50,4 +53,4 @@ class ManualPolicy(ControllerBase):
        new_node = self._transition()
        if new_node is not None:
            self.current_node = new_node

backends/mcts_learner.py

@@ -2,15 +2,15 @@ from .controller_base import ControllerBase
import numpy as np
import pickle


class MCTSLearner(ControllerBase):
    """Monte Carlo Tree Search implementation using the UCB1 and progressive
    widening approach as explained in Paxton et al (2017)."""

    _ucb_vals = set()

    def __init__(self, env, low_level_policies, start_node_alias,
                 max_depth=10):
        """Constructor for MCTSLearner.

        Args:

@@ -22,10 +22,12 @@ class MCTSLearner(ControllerBase):
            max_depth: max depth of the MCTS tree; default 10 levels
        """

        super(MCTSLearner, self).__init__(env, low_level_policies,
                                          start_node_alias)
        self.controller_args_defaults = {
            "predictor":
            None  #P(s, o) learner class; forward pass should return the entire value from state s and option o
        }
        self.max_depth = max_depth

        #: store current node alias

@@ -48,11 +50,17 @@ class MCTSLearner(ControllerBase):
        # populate root node
        root_node_num, root_node_info = self._create_node(self.curr_node_alias)
        self.nodes[root_node_num] = root_node_info
        self.adj[root_node_num] = set()  # no children

    def save_model(self, file_name="mcts.pickle"):
        to_backup = {
            'N': self.N,
            'M': self.M,
            'TR': self.TR,
            'nodes': self.nodes,
            'adj': self.adj,
            'new_node_num': self.new_node_num
        }
        with open(file_name, 'wb') as handle:
            pickle.dump(to_backup, handle, protocol=pickle.HIGHEST_PROTOCOL)

@@ -67,8 +75,8 @@ class MCTSLearner(ControllerBase):
        self.new_node_num = to_restore['new_node_num']

    def _create_node(self, low_level_policy):
        """Create the node associated with curr_node_num, using the given low
        level policy.

        Args:
            low_level_policy: the option's alias

@@ -83,11 +91,10 @@ class MCTSLearner(ControllerBase):
        return created_node_num, {"policy": low_level_policy}

    def _to_discrete(self, observation):
        """Converts observation to a discrete observation tuple. Also append
        (a) whether we are following a vehicle, and (b) whether there is a
        vehicle in the opposite lane in the approximately the same x position.
        These values will be useful for Follow and ChangeLane maneuvers.

        Args:
            observation: observation tuple from the environment

@@ -96,9 +103,9 @@ class MCTSLearner(ControllerBase):
        """
        dis_observation = ''
        for item in observation[12:20]:
            if type(item) == bool:
                dis_observation += '1' if item is True else '0'
            if type(item) == int and item in [0, 1]:
                dis_observation += str(item)

        env = self.current_node.env

@@ -128,9 +135,9 @@ class MCTSLearner(ControllerBase):
    def _get_visitation_count(self, observation, option=None):
        """Finds the visitation count of the discrete form of the observation.
        If discrete observation not found, then inserted into self.N with value
        0. Auto converts the observation into discrete form. If option is not
        None, then this uses self.M instead of self.N.

        Args:
            observation: observation tuple from the environment

@@ -163,16 +170,18 @@ class MCTSLearner(ControllerBase):
        dis_observation = self._to_discrete(observation)
        if (dis_observation, option) not in self.TR:
            self.TR[(dis_observation, option)] = 0
        return self.TR[(dis_observation, option)] / (
            1 + self._get_visitation_count(observation, option))

    def _get_possible_options(self):
        """Returns a set of options that can be taken from the current node.

        Goes through adjacency set of current node and finds which next
        nodes' initiation condition is met.
        """
        all_options = set(self.low_level_policies.keys())

        # Filter nodes whose initiation condition are true
        filtered_options = set()
        for option_alias in all_options:

@@ -201,9 +210,9 @@ class MCTSLearner(ControllerBase):
        return visited_aliases

    def _ucb_adjusted_q(self, observation, C=1):
        """Computes Q_star(observation, option_i) plus the UCB term, which is
        C*[predictor(observation, option_i)]/[1+N(observation, option_i)], for
        all option_i in the adjacency set of the current node.

        Args:
            observation: observation tuple from the environment

@@ -211,14 +220,15 @@ class MCTSLearner(ControllerBase):
        Returns Q values for next nodes
        """

        Q = {}
        Q1, Q2 = {}, {}  # debug
        dis_observation = self._to_discrete(observation)
        next_option_nums = self.adj[self.curr_node_num]
        for next_option_num in next_option_nums:
            next_option = self.nodes[next_option_num]["policy"]
            Q1[next_option] = (
                self._get_q_star(observation, next_option) + 200) / 400
            Q[(dis_observation, next_option)] = \
                Q1[next_option]
            Q2[next_option] = C * \

@@ -243,17 +253,17 @@ class MCTSLearner(ControllerBase):
        relevant_rewards = [value for key, value in self.TR.items() \
            if key[0] == dis_observation]
        sum_rewards = np.sum(relevant_rewards)
        return sum_rewards / (
            1 + self._get_visitation_count(observation))

    def _select(self, observation, depth=0, visualize=False):
        """MCTS selection function. For representation, we only use the
        discrete part of the observation.

        Args:
            observation: observation tuple from the environment
            depth: current depth, starts from root node, hence 0 by default
            visualize: whether or not to visualize low level steps

        Returns the sum of values from the given observation.
        """

@@ -266,7 +276,8 @@ class MCTSLearner(ControllerBase):
        if is_terminal or max_depth_reached:
            # print('MCTS went %d nodes deep' % depth)
            return self._value(
                observation), 0  # TODO: replace with final goal reward

        Ns = self._get_visitation_count(observation)
        Nchildren = len(self.adj[self.curr_node_num])

@@ -288,23 +299,24 @@ class MCTSLearner(ControllerBase):
            self.adj[self.curr_node_num].add(new_node_num)

        # Find o_star and do a transition, i.e. update curr_node
        # Simulate / lookup; first change next
        next_observation, episode_R, o_star = self.do_transition(
            observation, visualize=visualize)

        # Recursively select next node
        remaining_v, all_ep_R = self._select(
            next_observation, depth + 1, visualize=visualize)

        # Update values
        self.N[dis_observation] += 1
        self.M[(dis_observation, o_star)] += 1
        self.TR[(dis_observation, o_star)] += (episode_R + remaining_v)

        return self._value(observation), all_ep_R + episode_R

    def traverse(self, observation, visualize=False):
        """Do a complete traversal from root to leaf. Assumes the environment
        is reset and we are at the root node.

        Args:
            observation: observation from the environment

@@ -316,13 +328,13 @@ class MCTSLearner(ControllerBase):
        return self._select(observation, visualize=visualize)

    def do_transition(self, observation, visualize=False):
        """Do a transition using UCB metric, with the latest observation from
        the episodic step.

        Args:
            observation: final observation from episodic step
            visualize: whether or not to visualize low level steps

        Returns o_star using UCB metric
        """

@@ -368,4 +380,4 @@ class MCTSLearner(ControllerBase):
        next_keys, next_values = list(Q1.keys()), list(Q1.values())
        o_star = next_keys[np.argmax(next_values)]
        print(Q1)
        return o_star

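Per the _ucb_adjusted_q docstring above, each candidate option is scored as a scaled Q_star plus an exploration term C * predictor(s, o) / (1 + N(s, o)), and do_transition takes the argmax over those scores. A numerical sketch of that scoring with made-up inputs (all values below are illustrative; only the formula and the (q_star + 200) / 400 scaling come from the hunks above):

import numpy as np

C = 1.0
q_star = np.array([50.0, -120.0, 10.0])   # Q_star estimates per candidate option
predictor = np.array([0.6, 0.1, 0.3])     # predictor output P(s, o)
visits = np.array([4, 0, 1])              # visitation counts N(s, o)

q1 = (q_star + 200.0) / 400.0             # same scaling used in _ucb_adjusted_q
score = q1 + C * predictor / (1.0 + visits)
o_star = int(np.argmax(score))            # option index a do_transition-style rule would pick
print(score, o_star)
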
backends/online_mcts_controller.py

@@ -3,8 +3,9 @@ from .mcts_learner import MCTSLearner
import tqdm
import numpy as np


class OnlineMCTSController(ControllerBase):
    """Online MCTS."""

    def __init__(self, env, low_level_policies, start_node_alias):
        """Constructor for manual policy execution.

@@ -13,12 +14,13 @@ class OnlineMCTSController(ControllerBase):
            env: env instance
            low_level_policies: low level policies dictionary
        """
        super(OnlineMCTSController, self).__init__(env, low_level_policies,
                                                   start_node_alias)
        self.curr_node_alias = start_node_alias
        self.controller_args_defaults = {
            "predictor": None,
            "max_depth": 5,  # MCTS depth
            "nb_traversals": 30,  # MCTS traversals before decision
        }

    def set_current_node(self, node_alias):

@@ -48,12 +50,14 @@ class OnlineMCTSController(ControllerBase):
        env_before_mcts = orig_env.copy()
        self.change_low_level_references(env_before_mcts)
        print('Current Node: %s' % self.curr_node_alias)
        mcts = MCTSLearner(self.env, self.low_level_policies,
                           self.curr_node_alias)