wise-lab / wise-move

Commit 7a31ba33
Authored Feb 06, 2019 by Jae Young Lee

Merge branch 'master' into Train_0.1m_steps_and_improve_Wait

Parents: f25adee3, 54079509
Changes: 9 changed files with 557 additions and 888 deletions (+557 -888)
backends/__init__.py                                        +1    -1
backends/kerasrl_learner.py                                 +12   -8
backends/mcts_controller.py                                 +109  -0
backends/mcts_learner.py                                    +358  -302
backends/trained_policies/highlevel/highlevel_weights.h5f   +0    -0
mcts.py                                                     +68   -246
mcts_config.json                                            +5    -5
options/options_loader.py                                   +4    -11
options/simple_intersection/mcts_maneuvers.py               +0    -315
backends/__init__.py  (view file @ 7a31ba33)

@@ -2,4 +2,4 @@ from .manual_policy import ManualPolicy
 from .mcts_learner import MCTSLearner
 from .rl_controller import RLController
 from .kerasrl_learner import DDPGLearner, DQNLearner
-from .online_mcts_controller import OnlineMCTSController
\ No newline at end of file
+from .mcts_controller import MCTSController
\ No newline at end of file
backends/kerasrl_learner.py  (view file @ 7a31ba33)

@@ -13,7 +13,7 @@ from rl.policy import GreedyQPolicy, EpsGreedyQPolicy, MaxBoltzmannQPolicy
 from rl.callbacks import ModelIntervalCheckpoint
 import numpy as np
 import copy


 class DDPGLearner(LearnerBase):
     def __init__(self,
@@ -275,9 +275,6 @@ class DQNLearner(LearnerBase):
         model.add(Flatten(input_shape=(1, ) + self.input_shape))
         model.add(Dense(64, activation='relu'))
         model.add(Dense(64, activation='relu'))
-        model.add(Dense(64, activation='relu'))
-        model.add(Dense(64, activation='relu'))
-        model.add(Dense(64, activation='relu'))
         model.add(Dense(64, activation='tanh'))
         model.add(Dense(self.nb_actions))
         print(model.summary())
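Note: the hunk above trims the Q-network to two 64-unit relu layers followed by a 64-unit tanh layer and the action head. A minimal standalone sketch of the resulting architecture, for reference only (the helper name and default arguments are illustrative; the repository builds this model inside DQNLearner):

    from keras.models import Sequential
    from keras.layers import Dense, Flatten

    def build_q_network(input_shape=(50,), nb_actions=5):
        # Same layer stack as the diff above, expressed as a free function.
        model = Sequential()
        model.add(Flatten(input_shape=(1,) + input_shape))
        model.add(Dense(64, activation='relu'))
        model.add(Dense(64, activation='relu'))
        model.add(Dense(64, activation='tanh'))
        model.add(Dense(nb_actions))
        return model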
@@ -386,9 +383,10 @@ class DQNLearner(LearnerBase):
                         termination_reason_counter[termination_reason] += 1
                     else:
                         termination_reason_counter[termination_reason] = 1
                 env.reset()
-            if episode_reward >= success_reward_threshold:
+            #TODO: remove below env-specific code
+            if env.env.goal_achieved:
                 success_count += 1
             env.reset()
             print("Episode {}: steps:{}, reward:{}".format(n + 1, step, episode_reward))
@@ -416,9 +414,14 @@ class DQNLearner(LearnerBase):
         action_num = self.agent_model.low_level_policy_aliases.index(option_alias)
         q_values = self.agent_model.get_modified_q_values(observation)
-        max_q_value = np.abs(np.max(q_values))
-        q_values = [np.exp(q_value / max_q_value) for q_value in q_values]
+        # print('softq q_values are %s' % dict(zip(self.agent_model.low_level_policy_aliases, q_values)))
+        # oq_values = copy.copy(q_values)
+        if q_values[action_num] == -np.inf:
+            return 0
+        max_q_value = np.max(q_values)
+        q_values = [np.exp(q_value - max_q_value) for q_value in q_values]
         relevant = q_values[action_num] / np.sum(q_values)
+        # print('softq: %s -> %s' % (oq_values, relevant))
         return relevant
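The replacement normalization above amounts to a numerically stable softmax over the option Q-values: subtracting the maximum Q-value before exponentiating avoids overflow, and masked options (Q = -inf) get probability zero. A minimal standalone sketch, assuming the same masking convention (the function name is illustrative, not the repository's API):

    import numpy as np

    def softq_probability(q_values, action_num):
        # Stable softmax: shift by the max before exponentiating, then
        # return the probability mass assigned to the chosen option.
        q_values = np.asarray(q_values, dtype=float)
        if q_values[action_num] == -np.inf:   # option unavailable in this state
            return 0.0
        shifted = np.exp(q_values - np.max(q_values))
        return shifted[action_num] / np.sum(shifted)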
@@ -543,6 +546,7 @@ class DQNAgentOverOptions(DQNAgent):
         self.recent_observation = observation
         self.recent_action = action
+        # print('forward gives %s from %s' % (action, dict(zip(self.low_level_policy_aliases, q_values))))
         return action

     def get_modified_q_values(self, observation):
backends/online_mcts_controller.py → backends/mcts_controller.py  (view file @ 7a31ba33)
@@ -3,9 +3,8 @@ from .mcts_learner import MCTSLearner
 import tqdm
 import numpy as np


-class OnlineMCTSController(ControllerBase):
-    """Online MCTS."""
+class MCTSController(ControllerBase):
+    """MCTS Controller."""

     def __init__(self, env, low_level_policies, start_node_alias):
         """Constructor for manual policy execution.
@@ -14,13 +13,15 @@ class OnlineMCTSController(ControllerBase):
             env: env instance
             low_level_policies: low level policies dictionary
         """
-        super(OnlineMCTSController, self).__init__(env, low_level_policies,
+        super(MCTSController, self).__init__(env, low_level_policies,
                                                    start_node_alias)
         self.curr_node_alias = start_node_alias
         self.controller_args_defaults = {
             "predictor": None,
-            "max_depth": 5,  # MCTS depth
-            "nb_traversals": 30,  # MCTS traversals before decision
+            "max_depth": 10,  # MCTS depth
+            "nb_traversals": 100,  # MCTS traversals before decision
+            "debug": False,
+            "rollout_timeout": 500
         }

     def set_current_node(self, node_alias):
@@ -30,11 +31,21 @@ class OnlineMCTSController(ControllerBase):
         self.env.set_ego_info_text(node_alias)

     def change_low_level_references(self, env_copy):
-        # Create a copy of the environment and change references in low level policies.
+        """Change references in low level policies by updating the environment
+        with the copy of the environment.
+
+        Args:
+            env_copy: reference to copy of the environment
+        """
         self.env = env_copy
         for policy in self.low_level_policies.values():
             policy.env = env_copy

+    def check_env(self, x):
+        """Prints the object id of the environment. Debugging function."""
+        print('%s: self.env is %s' % (x, str(id(self.env))))
+
     def can_transition(self):
         return not self.env.is_terminal()
@@ -45,29 +56,54 @@ class OnlineMCTSController(ControllerBase):
                 "predictor is not set. Use set_controller_args().")
         # Store the env at this point
         orig_env = self.env
+        # self.check_env('i')
         np.random.seed()
         # Change low level references before init MCTSLearner instance
         env_before_mcts = orig_env.copy()
         self.change_low_level_references(env_before_mcts)
-        print('Current Node: %s' % self.curr_node_alias)
-        mcts = MCTSLearner(self.env, self.low_level_policies,
-                           self.curr_node_alias)
-        mcts.max_depth = self.max_depth
-        mcts.set_controller_args(predictor=self.predictor)
+        # self.check_env('b4')
+        # print('Current Node: %s' % self.curr_node_alias)
+        if not hasattr(self, 'mcts'):
+            if self.debug:
+                print('Creating MCTS Tree: max depth {}'.format(self.max_depth))
+            self.mcts = MCTSLearner(self.env, self.low_level_policies,
+                                    max_depth=self.max_depth,
+                                    debug=self.debug,
+                                    rollout_timeout=self.rollout_timeout)
+            self.mcts.set_controller_args(predictor=self.predictor)
+        if self.debug:
+            print('')
         # Do nb_traversals number of traversals, reset env to this point every time
-        # print('Doing MCTS with params: max_depth = %d, nb_traversals = %d' % (self.max_depth, self.nb_traversals))
-        for num_epoch in range(self.nb_traversals):  # tqdm.tqdm(range(self.nb_traversals)):
-            mcts.curr_node_num = 0
+        num_epoch = 0
+        if not self.debug:
+            progress = tqdm.tqdm(total=self.nb_traversals - self.mcts.tree.root.N)
+        while num_epoch < self.nb_traversals:  # tqdm
+            if self.mcts.tree.root.N >= self.nb_traversals:
+                break
             env_begin_epoch = env_before_mcts.copy()
             self.change_low_level_references(env_begin_epoch)
             # self.check_env('e%d' % num_epoch)
             init_obs = self.env.get_features_tuple()
-            v, all_ep_R = mcts.traverse(init_obs)
+            self.mcts.env = env_begin_epoch
+            if self.debug:
+                print('Search %d: ' % num_epoch, end=' ')
+            success = self.mcts.search(init_obs)
+            num_epoch += 1
+            if not self.debug:
+                progress.update(1)
+        if not self.debug:
+            progress.close()
         self.change_low_level_references(orig_env)
         # self.check_env('p')
         # Find the nodes from the root node
-        mcts.curr_node_num = 0
-        print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
-        node_after_transition = mcts.get_best_node(
-            self.env.get_features_tuple(), use_ucb=False)
-        print('MCTS suggested next option: %s' % node_after_transition)
-        self.set_current_node(node_after_transition)
+        # print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
+        node_after_transition = self.mcts.best_action(self.mcts.tree.root, 0)
+        if self.debug:
+            print('MCTS suggested next option: %s' % node_after_transition)
+        p = {'overall': self.mcts.tree.root.Q * 1.0 / self.mcts.tree.root.N}
+        for edge in self.mcts.tree.root.edges.keys():
+            next_node = self.mcts.tree.nodes[self.mcts.tree.root.edges[edge]]
+            p[edge] = next_node.Q * 1.0 / next_node.N
+        if self.debug:
+            print('Btw: %s' % str(p))
+        self.mcts.tree.reconstruct(node_after_transition)
+        self.set_current_node(node_after_transition)
\ No newline at end of file
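Note: the controller above now builds the MCTSLearner once, keeps searching until the root has been visited nb_traversals times, and re-roots the tree at the chosen option via tree.reconstruct() so its statistics carry over to the next decision. A rough illustrative sketch of that pattern, not the repository's MCTSLearner API:

    class Node:
        def __init__(self):
            self.N = 0          # visit count
            self.Q = 0.0        # accumulated return
            self.children = {}  # option alias -> child Node

    def decide(root, nb_traversals, run_one_search):
        # run_one_search(root) is assumed to do one selection/expansion/
        # rollout/backup pass and update N and Q along the visited path.
        while root.N < nb_traversals:
            run_one_search(root)
        # Pick the child with the best mean return, then hand it back as the
        # new root so its subtree is reused for the next decision.
        best_alias = max(root.children,
                         key=lambda a: root.children[a].Q / max(root.children[a].N, 1))
        return best_alias, root.children[best_alias]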
backends/mcts_learner.py  (view file @ 7a31ba33)
(This diff is collapsed.)
backends/trained_policies/highlevel/highlevel_weights.h5f  (view file @ 7a31ba33)
(No preview for this file type.)
mcts.py  (view file @ 7a31ba33)
@@ -4,7 +4,7 @@ from backends import DDPGLearner, DQNLearner, MCTSLearner
 import numpy as np
 import tqdm
 import argparse
 import time, datetime
 import sys
@@ -28,310 +28,132 @@ class Logger(object):
 sys.stdout = Logger()


-# TODO: make a separate file for this function.
-def mcts_training(nb_traversals, save_every=20, visualize=False,
-                  load_saved=False, save_file="mcts.pickle"):
-    """Do RL of the low-level policy of the given maneuver and test it.
-
-    Args:
-        nb_traversals: number of MCTS traversals
-        save_every: save at every these many traversals
-        visualize: visualization / rendering
-    """
-    # initialize the numpy random number generator
-    np.random.seed()
-
-    # load options graph
-    options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv)
-    options.load_trained_low_level_policies()
-
-    agent = DQNLearner(input_shape=(50, ),
-                       nb_actions=options.get_number_of_nodes(),
-                       low_level_policies=options.maneuvers)
-    agent.load_model("backends/trained_policies/highlevel/highlevel_weights.h5f")
-    options.set_controller_args(predictor=agent.get_softq_value_using_option_alias)
-    options.controller.max_depth = 20
-
-    if load_saved:
-        options.controller.load_model(save_file)
-
-    total_epochs = nb_traversals // save_every
-    trav_num = 1
-    print('Total number of epochs = %d' % total_epochs)
-    for num_epoch in range(total_epochs):
-        last_rewards = []
-        beg_trav_num = trav_num
-        for num_traversal in tqdm.tqdm(range(save_every)):
-            options.controller.curr_node_num = 0
-            init_obs = options.reset()
-            v, all_ep_R = options.controller.traverse(init_obs, visualize=visualize)
-            last_rewards += [all_ep_R]
-            trav_num += 1
-        options.controller.save_model(save_file)
-        success = lambda x: x > 50
-        success_rate = np.sum(list(map(success, last_rewards))) / (len(last_rewards) * 1.0)
-        print('success rate: %f' % success_rate)
-        print('Average Reward (%d-%d): %f\n' % (beg_trav_num, trav_num - 1, np.mean(last_rewards)))
-
-
-def mcts_evaluation(nb_traversals, num_trials=5, visualize=False,
-                    save_file="mcts.pickle", pretrained=False):
+def mcts_evaluation(depth, nb_traversals, nb_episodes, nb_trials,
+                    visualize=False, debug=False):
     """Do RL of the low-level policy of the given maneuver and test it.

     Args:
-        nb_traversals: number of MCTS traversals
-        save_every: save at every these many traversals
+        depth: depth of each tree search
+        nb_traversals: number of MCTS traversals per episodes
+        nb_episodes: number of episodes per trial
+        nb_trials: number of trials
         visualize: visualization / rendering
+        debug: whether or not to show debug information
     """
     # initialize the numpy random number generator
     np.random.seed()

-    # load options graph
-    options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv)
+    # load config and maneuvers
+    options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv,
+                           randomize_special_scenarios=True)
     options.load_trained_low_level_policies()

+    # load high level policy for UCT prediction
     agent = DQNLearner(input_shape=(50, ),
                        nb_actions=options.get_number_of_nodes(),
                        low_level_policies=options.maneuvers)
     agent.load_model("backends/trained_policies/highlevel/highlevel_weights.h5f")
-    options.set_controller_args(predictor=agent.get_softq_value_using_option_alias)
-    options.controller.max_depth = 20
-
-    if pretrained:
-        save_file = "backends/trained_policies/mcts/" + save_file
-
-    success_list = []
-    print('Total number of trials = %d' % num_trials)
-    for trial in range(num_trials):
-        num_successes = 0
-        options.controller.load_model(save_file)
-        for num_traversal in tqdm.tqdm(range(nb_traversals)):
-            options.controller.curr_node_num = 0
-            init_obs = options.reset()
-            v, all_ep_R = options.controller.traverse(init_obs, visualize=visualize)
-            if all_ep_R > 50:
-                num_successes += 1
-        print("\nTrial {}: success: {}".format(trial + 1, num_successes))
-        success_list.append(num_successes)
-    print("\nSuccess: Avg: {}, Std: {}".format(np.mean(success_list), np.std(success_list)))
-
-
-def online_mcts(nb_episodes=10):
-    # MCTS visualization is off
-
-    # initialize the numpy random number generator
-    np.random.seed()
-
-    # load options graph
-    options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv)
-    options.load_trained_low_level_policies()
-
-    agent = DQNLearner(input_shape=(50, ),
-                       nb_actions=options.get_number_of_nodes(),
-                       low_level_policies=options.maneuvers)
-    agent.load_model("backends/trained_policies/highlevel/highlevel_weights_772.h5f")
-
-    # set predictor
-    options.set_controller_args(predictor=agent.get_softq_value_using_option_alias)
-
-    # Loop
-    num_successes = 0
-    for num_ep in range(nb_episodes):
-        init_obs = options.reset()
-        episode_reward = 0
-        first_time = True
-        while not options.env.is_terminal():
-            if first_time:
-                first_time = False
-            else:
-                print('Stepping through ...')
-                features, R, terminal, info = options.controller.\
-                    step_current_node(visualize_low_level_steps=True)
-                episode_reward += R
-                print('Intermediate Reward: %f (ego x = %f)' % (R, options.env.vehs[0].x))
-                print('')
-            if options.controller.can_transition():
-                options.controller.do_transition()
-        print('')
-        print('EPISODE %d: Reward = %f' % (num_ep, episode_reward))
-        print('')
-        print('')
-        if episode_reward > 50:
-            num_successes += 1
-    print("Policy succeeded {} times!".format(num_successes))
-
-
-def evaluate_online_mcts(nb_episodes=20, nb_trials=5):
-    # MCTS visualization is off
-
-    # initialize the numpy random number generator
-    np.random.seed()
-
-    # load options graph
-    options = OptionsGraph("mcts_config.json", SimpleIntersectionEnv)
-    options.load_trained_low_level_policies()
-
-    agent = DQNLearner(input_shape=(50, ),
-                       nb_actions=options.get_number_of_nodes(),
-                       low_level_policies=options.maneuvers)
-    agent.load_model("backends/trained_policies/highlevel/highlevel_weights_772.h5f")
-
-    options.set_controller_args(predictor=agent.get_softq_value_using_option_alias)
+    options.set_controller_args(
+        predictor=agent.get_softq_value_using_option_alias,
+        max_depth=depth,
+        nb_traversals=nb_traversals,
+        debug=debug)

-    success_list = []
-    print("\nConducting {} trials of {} episodes each".format(nb_trials, nb_episodes))
-    success_list = []
-    termination_reason_list = {}
-    for trial in range(nb_trials):
-        # Loop
+    # Evaluate
+    overall_reward_list = []
+    overall_success_accuracy = []
+    for num_tr in range(nb_trials):
         num_successes = 0
-        termination_reason_counter = {}
+        reward_list = []
         for num_ep in range(nb_episodes):
             init_obs = options.reset()
             episode_reward = 0
             first_time = True
+            start_time = time.time()
             while not options.env.is_terminal():
                 if first_time:
                     first_time = False
                 else:
-                    print('Stepping through ...')
+                    # print('Stepping through ...')
                     features, R, terminal, info = options.controller.\
-                        step_current_node(visualize_low_level_steps=True)
+                        step_current_node(visualize_low_level_steps=visualize)
                     episode_reward += R
-                    print('Intermediate Reward: %f (ego x = %f)' % (R, options.env.vehs[0].x))
-                    print('')
-                    if terminal:
-                        if 'episode_termination_reason' in info:
-                            termination_reason = info['episode_termination_reason']
-                            if termination_reason in termination_reason_counter:
-                                termination_reason_counter[termination_reason] += 1
-                            else:
-                                termination_reason_counter[termination_reason] = 1
+                    # print('Intermediate Reward: %f (ego x = %f)' %
+                    #       (R, options.env.vehs[0].x))
+                    # print('')
                 if options.controller.can_transition():
                     options.controller.do_transition()
-            print('')
-            print('EPISODE %d: Reward = %f' % (num_ep, episode_reward))
-            print('')
-            print('')
-            if episode_reward > 50:
-                num_successes += 1
-        print("\nTrial {}: success: {}".format(trial + 1, num_successes))
-        success_list.append(num_successes)
-        for reason, count in termination_reason_counter.items():
-            if reason in termination_reason_list:
-                termination_reason_list[reason].append(count)
-            else:
-                termination_reason_list[reason] = [count]
-    success_list = np.array(success_list)
-    print("\nSuccess: Avg: {}, Std: {}".format(np.mean(success_list), np.std(success_list)))
-    print("Termination reason(s):")
-    for reason, count_list in termination_reason_list.items():
-        count_list = np.array(count_list)
-        print("{}: Avg: {}, Std: {}".format(reason, np.mean(count_list), np.std(count_list)))
+            end_time = time.time()
+            total_time = int(end_time - start_time)
+            if options.env.goal_achieved:
+                num_successes += 1
+            print('Episode {}: Reward = {} ({})'.format(
+                num_ep, episode_reward, datetime.timedelta(seconds=total_time)))
+            reward_list += [episode_reward]
+        print("Trial {}: Reward = (Avg: {}, Std: {}), Successes: {}/{}".\
+              format(num_tr, np.mean(reward_list), np.std(reward_list),\
+                     num_successes, nb_episodes))
+        overall_reward_list += reward_list
+        overall_success_accuracy += [num_successes * 1.0 / nb_episodes]
+    print('Overall: Reward = (Avg: {}, Std: {}), Success = (Avg: {}, Std: {})'.\
+          format(np.mean(overall_reward_list), np.std(overall_reward_list),
+                 np.mean(overall_success_accuracy), np.std(overall_success_accuracy)))


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--train",
-        help="Train an offline mcts with default settings. Always saved in root folder.",
-        action="store_true")
     parser.add_argument(
         "--evaluate",
-        help="Evaluate over n trials. "
-        "Uses backends/trained_policies/mcts/mcts.pickle by default",
-        action="store_true")
-    parser.add_argument(
-        "--saved_policy_in_root",
-        help="Use saved policies in root of project rather than backends/trained_policies/mcts/",
-        action="store_true")
-    parser.add_argument(
-        "--load_saved",
-        help="Load a saved policy from root folder first before training",
+        help="Evaluate over n trials, no visualization by default.",
         action="store_true")
     parser.add_argument(
         "--visualize",
-        help="Visualize the training. Testing is always visualized. Evaluation is not visualized by default",
+        help="Visualize the training.",
         action="store_true")
     parser.add_argument(
-        "--nb_traversals",
-        help="Number of traversals to perform. Default is 1000",
-        default=1000,
+        "--depth",
+        help="Max depth of tree per episode. Default is 10",
+        default=10,
         type=int)
     parser.add_argument(
-        "--save_every",
-        help="Saves every n traversals. Saves in root by default. Default is 500",
-        default=500,
+        "--nb_traversals",
+        help="Number of traversals to perform per episode. Default is 100",
+        default=100,
         type=int)
     parser.add_argument(
-        "--nb_traversals_for_test",
-        help="Number of episodes to evaluate. Default is 100",
-        default=100,
+        "--nb_episodes",
+        help="Number of episodes per trial to evaluate. Default is 10",
+        default=10,
         type=int)
     parser.add_argument(
         "--nb_trials",
-        help="Number of trials to evaluate. Default is 10",
-        default=10,
+        help="Number of trials to evaluate. Default is 1",
+        default=1,
         type=int)
     parser.add_argument(
-        "--save_file",
-        help="filename to save/load the trained policy. Location is as specified by --saved_policy_in_root. Default name is mcts.pickle",
-        default="mcts.pickle")
+        "--debug",
+        help="Show debug output. Default is false",
+        action="store_true")

     args = parser.parse_args()

-    if args.train:
-        mcts_training(nb_traversals=args.nb_traversals,
-                      save_every=args.save_every,
-                      visualize=args.visualize,
-                      load_saved=args.load_saved,
-                      save_file=args.save_file)
     if args.evaluate:
-        mcts_evaluation(nb_traversals=args.nb_traversals_for_test,
-                        num_trials=args.nb_trials,
+        mcts_evaluation(depth=args.depth,
+                        nb_traversals=args.nb_traversals,
+                        nb_episodes=args.nb_episodes,
+                        nb_trials=args.nb_trials,
                         visualize=args.visualize,
-                        pretrained=not args.saved_policy_in_root,
-                        save_file=args.save_file)
+                        debug=args.debug)
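Note: given the argument definitions above, the new evaluation entry point would presumably be invoked along these lines (the values shown are simply the documented defaults, for illustration):

    python mcts.py --evaluate --depth 10 --nb_traversals 100 --nb_episodes 10 --nb_trials 1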
mcts_config.json  (view file @ 7a31ba33)

 {
     "nodes": {
-        "wait": "MCTSWait",
-        "follow": "MCTSFollow",
-        "stop": "MCTSStop",
-        "changelane": "MCTSChangeLane",
-        "keeplane": "MCTSKeepLane"
+        "wait": "Wait",
+        "follow": "Follow",
+        "stop": "Stop",
+        "changelane": "ChangeLane",
+        "keeplane": "KeepLane"
     },
     "edges": {
options/options_loader.py  (view file @ 7a31ba33)

 import json
 import os  # for the use of os.path.isfile
 from .simple_intersection.maneuvers import *
-from .simple_intersection.mcts_maneuvers import *
-from backends import RLController, DDPGLearner, MCTSLearner, OnlineMCTSController, ManualPolicy
+from backends import RLController, DDPGLearner, MCTSController, ManualPolicy


 class OptionsGraph:
     """Represent the options graph as a graph like structure. The configuration
@@ -67,11 +65,8 @@ class OptionsGraph:
             self.controller = ManualPolicy(self.env, self.maneuvers,
                                            self.adj, self.start_node_alias)
         elif self.config["method"] == "mcts":
-            self.controller = MCTSLearner(self.env, self.maneuvers,
-                                          self.start_node_alias)
-        elif self.config["method"] == "online_mcts":
-            self.controller = OnlineMCTSController(self.env, self.maneuvers,
-                                                   self.start_node_alias)
+            self.controller = MCTSController(self.env, self.maneuvers,
+                                             self.start_node_alias)
         else:
             raise Exception(self.__class__.__name__ + \
                             "Controller to be used not specified")
@@ -156,6 +151,7 @@ class OptionsGraph:
     # TODO: error handling
     def load_trained_low_level_policies(self):
         for key, maneuver in self.maneuvers.items():
+            # TODO: Ensure that for manual policies, nothing is loaded
             trained_policy_path = "backends/trained_policies/" + key + "/"
             critic_file_exists = os.path.isfile(trained_policy_path + key + "_weights_critic.h5f")
             actor_file_exists = os.path.isfile(trained_policy_path + key + "_weights_actor.h5f")
@@ -180,9 +176,6 @@ class OptionsGraph:
                 print("\nWarning: the trained low-level policy of \"" + key +
                       "\" does not exists; the manual policy will be used.\n")
-            if self.config["method"] == "mcts":
-                maneuver.timeout = np.inf
-
     def get_number_of_nodes(self):
         return len(self.maneuvers)
options/simple_intersection/mcts_maneuvers.py  (deleted 100644 → 0; view file @ f25adee3)

-from .maneuver_base import ManeuverBase
-from env.simple_intersection.constants import *
-import env.simple_intersection.road_geokinemetry as rd
-from env.simple_intersection.features import extract_ego_features, extract_other_veh_features
-from verifier.simple_intersection import LTLProperty
-import numpy as np