Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
wise-move
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
wise-lab
wise-move
Commits
051c8502
Commit
051c8502
authored
6 years ago
by
Ashish Gaurav
Browse files
Options
Downloads
Patches
Plain Diff
add rollout timeout for MCTS
parent
cb171a91
No related branches found
No related tags found
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
backends/mcts_controller.py
+3
-1
3 additions, 1 deletion
backends/mcts_controller.py
backends/mcts_learner.py
+13
-5
13 additions, 5 deletions
backends/mcts_learner.py
with
16 additions
and
6 deletions
backends/mcts_controller.py
+
3
−
1
View file @
051c8502
...
...
@@ -21,6 +21,7 @@ class MCTSController(ControllerBase):
"
max_depth
"
:
10
,
# MCTS depth
"
nb_traversals
"
:
100
,
# MCTS traversals before decision
"
debug
"
:
False
,
"
rollout_timeout
"
:
500
}
def
set_current_node
(
self
,
node_alias
):
...
...
@@ -65,7 +66,8 @@ class MCTSController(ControllerBase):
if
not
hasattr
(
self
,
'
mcts
'
):
if
self
.
debug
:
print
(
'
Creating MCTS Tree: max depth {}
'
.
format
(
self
.
max_depth
))
self
.
mcts
=
MCTSLearner
(
self
.
env
,
self
.
low_level_policies
,
max_depth
=
self
.
max_depth
,
debug
=
self
.
debug
)
self
.
mcts
=
MCTSLearner
(
self
.
env
,
self
.
low_level_policies
,
max_depth
=
self
.
max_depth
,
debug
=
self
.
debug
,
rollout_timeout
=
self
.
rollout_timeout
)
self
.
mcts
.
set_controller_args
(
predictor
=
self
.
predictor
)
if
self
.
debug
:
print
(
''
)
...
...
This diff is collapsed.
Click to expand it.
backends/mcts_learner.py
+
13
−
5
View file @
051c8502
...
...
@@ -116,7 +116,8 @@ class Tree:
class
MCTSLearner
(
ControllerBase
):
"""
MCTS Logic.
"""
def
__init__
(
self
,
env
,
low_level_policies
,
max_depth
=
10
,
debug
=
False
):
def
__init__
(
self
,
env
,
low_level_policies
,
max_depth
=
10
,
debug
=
False
,
rollout_timeout
=
500
):
"""
Constructor for MCTSLearner.
Args:
...
...
@@ -124,6 +125,7 @@ class MCTSLearner(ControllerBase):
low_level_policies: given low level maneuvers
max_depth: the tree's max depth
debug: whether or not to print debug statements
rollout_timeout: maximum number of rollout loop iterations before the rollout is cut off
"""
self
.
env
=
env
# super?
...
...
@@ -131,6 +133,7 @@ class MCTSLearner(ControllerBase):
self
.
controller_args_defaults
=
{
"
predictor
"
:
None
}
self
.
tree
=
Tree
(
max_depth
=
max_depth
)
self
.
debug
=
debug
self
.
rollout_timeout
=
rollout_timeout
def
reset
(
self
):
"""
Resets maneuvers and sets current node to root.
"""
...
...
@@ -249,11 +252,15 @@ class MCTSLearner(ControllerBase):
# print('Reached depth %d' % self.tree.nodes[self.tree.curr_node_num].depth, end=' ')
# print('at node: %d, reached leaf: %s, terminated: %s' % (self.tree.curr_node_num, reached_leaf, self.env.is_terminal()))
if
reached_leaf
:
rollout_reward
=
self
.
def_policy
()
# from leaf node
rollout_reward
,
timed_out
=
self
.
def_policy
()
# from leaf node
if
rollout_reward
>
0
:
self
.
backup
(
1.0
)
# from leaf node
success
=
1
elif
rollout_reward
<
-
150
:
elif
rollout_reward
<
-
150
or
timed_out
:
# TODO: -150 is arbitrary. It should be set from outside or
# provided through a variable, based on the env. Same for timeout,
# it is specific to the env. Also, for smaller timeouts it
# may be a good idea to propagate 0 instead of -1.
self
.
backup
(
-
1.0
)
else
:
self
.
backup
(
0
)
...
...
@@ -353,7 +360,7 @@ class MCTSLearner(ControllerBase):
rollout_reward
=
0
obs
=
self
.
tree
.
latest_obs
it
=
0
while
not
self
.
env
.
is_terminal
():
while
(
not
self
.
env
.
is_terminal
()
)
and
it
<
self
.
rollout_timeout
:
it
+=
1
possible_options
=
self
.
_get_possible_options
()
# print('possible is %s' % possible_options)
...
...
@@ -376,9 +383,10 @@ class MCTSLearner(ControllerBase):
if
eps_R
!=
None
:
rollout_reward
+=
eps_R
# print('Rollout steps = %d' % it)
timed_out
=
(
it
<
self
.
rollout_timeout
)
if
self
.
debug
:
print
(
'
<<%g>>
'
%
rollout_reward
,
end
=
'
'
)
return
rollout_reward
return
rollout_reward
,
timed_out
def
backup
(
self
,
rollout_reward
):
"""
Reward backup strategy.
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment