diff --git a/backends/mcts_controller.py b/backends/mcts_controller.py
index 619a2260a56cf433da00c4b4141fbac22c200df8..2644419b9fe8f9ba9a85969e7e9bd2069d12b123 100644
--- a/backends/mcts_controller.py
+++ b/backends/mcts_controller.py
@@ -21,6 +21,7 @@ class MCTSController(ControllerBase):
             "max_depth": 10,  # MCTS depth
             "nb_traversals": 100,  # MCTS traversals before decision
             "debug": False,
+            "rollout_timeout": 500
         }
 
     def set_current_node(self, node_alias):
@@ -65,7 +66,8 @@ class MCTSController(ControllerBase):
         if not hasattr(self, 'mcts'):
             if self.debug:
                 print('Creating MCTS Tree: max depth {}'.format(self.max_depth))
-            self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth, debug=self.debug)
+            self.mcts = MCTSLearner(self.env, self.low_level_policies, max_depth=self.max_depth,
+                                    debug=self.debug, rollout_timeout=self.rollout_timeout)
             self.mcts.set_controller_args(predictor=self.predictor)
         if self.debug: 
             print('')
diff --git a/backends/mcts_learner.py b/backends/mcts_learner.py
index 0248d6ae48e476c51197110b08c87d638b723d69..aa8ca4835293979afd4a1da043c4b662b558a13e 100644
--- a/backends/mcts_learner.py
+++ b/backends/mcts_learner.py
@@ -116,7 +116,8 @@ class Tree:
 class MCTSLearner(ControllerBase):
     """MCTS Logic."""
 
-    def __init__(self, env, low_level_policies, max_depth=10, debug=False):
+    def __init__(self, env, low_level_policies, max_depth=10, debug=False,
+        rollout_timeout=500):
         """Constructor for MCTSLearner.
 
         Args:
@@ -124,6 +125,7 @@ class MCTSLearner(ControllerBase):
             low_level_policies: given low level maneuvers
             max_depth: the tree's max depth
             debug: whether or not to print debug statements
+            rollout_timeout: timeout for the rollout
         """
 
         self.env = env # super?
@@ -131,6 +133,7 @@ class MCTSLearner(ControllerBase):
         self.controller_args_defaults = {"predictor": None}
         self.tree = Tree(max_depth=max_depth)
         self.debug = debug
+        self.rollout_timeout = rollout_timeout
 
     def reset(self):
         """Resets maneuvers and sets current node to root."""
@@ -249,11 +252,15 @@ class MCTSLearner(ControllerBase):
         # print('Reached depth %d' % self.tree.nodes[self.tree.curr_node_num].depth, end=' ')
         # print('at node: %d, reached leaf: %s, terminated: %s' % (self.tree.curr_node_num, reached_leaf, self.env.is_terminal()))
         if reached_leaf:
-            rollout_reward = self.def_policy() # from leaf node
+            rollout_reward, timed_out = self.def_policy() # from leaf node
             if rollout_reward > 0:
                 self.backup(1.0) # from leaf node
                 success = 1
-            elif rollout_reward < -150:
+            elif rollout_reward < -150 or timed_out:
+                # TODO: -150 is arbitrary. It should be set from outside or
+                # provided through a variable, based on the env. Same for timeout,
+                # it is specific to the env. Also, for smaller timeouts it
+                # may be a good idea to propagate 0 instead of -1.
                 self.backup(-1.0)
             else:
                 self.backup(0)
@@ -353,7 +360,7 @@ class MCTSLearner(ControllerBase):
         rollout_reward = 0
         obs = self.tree.latest_obs
         it = 0
-        while not self.env.is_terminal():
+        while (not self.env.is_terminal()) and it < self.rollout_timeout:
             it += 1
             possible_options = self._get_possible_options()
             # print('possible is %s' % possible_options)
@@ -376,9 +383,10 @@ class MCTSLearner(ControllerBase):
             if eps_R != None:
                 rollout_reward += eps_R
         # print('Rollout steps = %d' % it)
+        timed_out = (it < self.rollout_timeout)
         if self.debug:
             print(' <<%g>>' % rollout_reward, end='  ')
-        return rollout_reward
+        return rollout_reward, timed_out
 
     def backup(self, rollout_reward):
         """Reward backup strategy.