Commit 72e44d55 authored by Ashish Gaurav

format using yapf

parent e1fdb162
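The diff below is the output of running yapf over the repository's Python sources (on the command line this is typically yapf --in-place --recursive over the source tree). As a minimal sketch of reproducing that kind of formatting through yapf's Python API — the snippet being formatted, the chosen style, and the defensive handling of FormatCode's return value are illustrative assumptions, not taken from this repository:

    # Sketch only: reformat a code string with yapf (assumes yapf is installed).
    from yapf.yapflib.yapf_api import FormatCode

    snippet = "agent.compile( [ Adam(lr=lr * 1e-2, clipnorm=1.), Adam(lr=lr, clipnorm=1.) ], metrics=['mae'] )"

    # Older yapf releases return (formatted_code, changed); newer ones may return
    # only the string, so unpack defensively.
    result = FormatCode(snippet, style_config="pep8")
    formatted = result[0] if isinstance(result, tuple) else result
    print(formatted)  # the long call gets rewrapped, much like the hunks below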
@@ -35,7 +35,9 @@ class PPO2Agent(LearnerBase):
         self.log_path = log_path
-        self.env = DummyVecEnv([lambda: env]) #PPO2 required a vectorized environment for parallel training
+        self.env = DummyVecEnv([
+            lambda: env
+        ]) #PPO2 required a vectorized environment for parallel training
         self.agent_model = self.create_agent(policy, tensorboard)

     def get_default_policy(self):
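As the comment in this hunk notes, PPO2 only accepts vectorized environments. A small, self-contained sketch of that pattern with the stable_baselines API this agent wraps (the import paths and the gym environment name are illustrative, not taken from this file):

    import gym
    from stable_baselines import PPO2
    from stable_baselines.common.vec_env import DummyVecEnv

    # Wrap a single env in DummyVecEnv; PPO2 expects a VecEnv even when it is
    # not actually training several environments in parallel.
    env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    model = PPO2("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000)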
@@ -52,7 +54,8 @@ class PPO2Agent(LearnerBase):
             stable_baselines PPO2 object
         """
         if tensorboard:
-            return PPO2(policy, self.env, verbose=1, tensorboard_log=self.log_path)
+            return PPO2(
+                policy, self.env, verbose=1, tensorboard_log=self.log_path)
         else:
             return PPO2(policy, self.env, verbose=1)

@@ -100,7 +103,8 @@ class PPO2Agent(LearnerBase):
             episode_rewards[-1] += rewards[0]
             if dones[0] or current_step > nb_max_episode_steps:
                 obs = self.env.reset()
-                print ("Episode ", current_episode, "reward: ", episode_rewards[-1])
+                print("Episode ", current_episode, "reward: ",
+                      episode_rewards[-1])
                 episode_rewards.append(0.0)
                 current_episode += 1
                 current_step = 0
@@ -50,11 +50,13 @@ class ControllerBase(PolicyBase):
     Returns state at end of node execution, total reward, episode_termination_flag, info
     '''

     def step_current_node(self, visualize_low_level_steps=False):
         total_reward = 0
         self.node_terminal_state_reached = False
         while not self.node_terminal_state_reached:
-            observation, reward, terminal, info = self.low_level_step_current_node()
+            observation, reward, terminal, info = self.low_level_step_current_node(
+            )
             if visualize_low_level_steps:
                 self.env.render()
             total_reward += reward
@@ -70,9 +72,12 @@ class ControllerBase(PolicyBase):
     Returns state after one step, step reward, episode_termination_flag, info
     '''

     def low_level_step_current_node(self):
-        u_ego = self.current_node.low_level_policy(self.current_node.get_reduced_features_tuple())
+        u_ego = self.current_node.low_level_policy(
+            self.current_node.get_reduced_features_tuple())
         feature, R, terminal, info = self.current_node.step(u_ego)
         self.node_terminal_state_reached = terminal
-        return self.env.get_features_tuple(), R, self.env.termination_condition, info
+        return self.env.get_features_tuple(
+        ), R, self.env.termination_condition, info
@@ -47,7 +47,8 @@ class DDPGLearner(LearnerBase):
             "oup_mu": 0,  # OrnsteinUhlenbeckProcess mu
             "oup_sigma": 1,  # OrnsteinUhlenbeckProcess sigma
             "oup_sigma_min": 0.5,  # OrnsteinUhlenbeckProcess sigma min
-            "oup_annealing_steps": 500000,  # OrnsteinUhlenbeckProcess n-step annealing
+            "oup_annealing_steps":
+            500000,  # OrnsteinUhlenbeckProcess n-step annealing
             "nb_steps_warmup_critic": 100,  # steps for critic to warmup
             "nb_steps_warmup_actor": 100,  # steps for actor to warmup
             "target_model_update": 1e-3  # target model update frequency
@@ -160,24 +161,33 @@ class DDPGLearner(LearnerBase):
             target_model_update=1e-3)

         # TODO: give params like lr_actor and lr_critic to set different lr of Actor and Critic.
-        agent.compile([Adam(lr=self.lr*1e-2, clipnorm=1.), Adam(lr=self.lr, clipnorm=1.)], metrics=['mae'])
+        agent.compile(
+            [
+                Adam(lr=self.lr * 1e-2, clipnorm=1.),
+                Adam(lr=self.lr, clipnorm=1.)
+            ],
+            metrics=['mae'])

         return agent

     def train(self,
               env,
               nb_steps=1000000,
               visualize=False,
               verbose=1,
               log_interval=10000,
               nb_max_episode_steps=200,
               model_checkpoints=False,
               checkpoint_interval=100000,
               tensorboard=False):

         callbacks = []
         if model_checkpoints:
-            callbacks += [ModelIntervalCheckpoint('./checkpoints/checkpoint_weights.h5f', interval=checkpoint_interval)]
+            callbacks += [
+                ModelIntervalCheckpoint(
+                    './checkpoints/checkpoint_weights.h5f',
+                    interval=checkpoint_interval)
+            ]
         if tensorboard:
             callbacks += [TensorBoard(log_dir='./logs')]
@@ -291,28 +301,36 @@ class DQNLearner(LearnerBase):
         Returns:
             KerasRL DQN object
         """
-        agent = DQNAgentOverOptions(model=model, low_level_policies=self.low_level_policies,
-                                    nb_actions=self.nb_actions, memory=memory,
-                                    nb_steps_warmup=self.nb_steps_warmup, target_model_update=self.target_model_update,
-                                    policy=policy, enable_dueling_network=True)
+        agent = DQNAgentOverOptions(
+            model=model,
+            low_level_policies=self.low_level_policies,
+            nb_actions=self.nb_actions,
+            memory=memory,
+            nb_steps_warmup=self.nb_steps_warmup,
+            target_model_update=self.target_model_update,
+            policy=policy,
+            enable_dueling_network=True)
         agent.compile(Adam(lr=self.lr), metrics=['mae'])
         return agent

     def train(self,
               env,
               nb_steps=1000000,
               visualize=False,
               nb_max_episode_steps=200,
               tensorboard=False,
               model_checkpoints=False,
               checkpoint_interval=10000):
         callbacks = []
         if model_checkpoints:
-            callbacks += [ModelIntervalCheckpoint('./checkpoints/checkpoint_weights.h5f', interval=checkpoint_interval)]
+            callbacks += [
+                ModelIntervalCheckpoint(
+                    './checkpoints/checkpoint_weights.h5f',
+                    interval=checkpoint_interval)
+            ]
         if tensorboard:
             callbacks += [TensorBoard(log_dir='./logs')]
@@ -333,7 +351,7 @@ class DQNLearner(LearnerBase):
                    nb_episodes=5,
                    visualize=True,
                    nb_max_episode_steps=400,
-                   success_reward_threshold = 100):
+                   success_reward_threshold=100):

         print("Testing for {} episodes".format(nb_episodes))
         success_count = 0

@@ -359,13 +377,14 @@ class DQNLearner(LearnerBase):
                 env.reset()
                 if episode_reward >= success_reward_threshold:
                     success_count += 1
-                print("Episode {}: steps:{}, reward:{}".format(n+1, step, episode_reward))
+                print("Episode {}: steps:{}, reward:{}".format(
+                    n + 1, step, episode_reward))

-        print ("\nPolicy succeeded {} times!".format(success_count))
-        print ("Failures due to:")
-        print (termination_reason_counter)
-        return [success_count,termination_reason_counter]
+        print("\nPolicy succeeded {} times!".format(success_count))
+        print("Failures due to:")
+        print(termination_reason_counter)
+        return [success_count, termination_reason_counter]

     def load_model(self, file_name="test_weights.h5f"):
         self.agent_model.load_weights(file_name)
@@ -377,27 +396,39 @@ class DQNLearner(LearnerBase):
         return self.agent_model.get_modified_q_values(observation)[action]

     def get_q_value_using_option_alias(self, observation, option_alias):
-        action_num = self.agent_model.low_level_policy_aliases.index(option_alias)
+        action_num = self.agent_model.low_level_policy_aliases.index(
+            option_alias)
         return self.agent_model.get_modified_q_values(observation)[action_num]

     def get_softq_value_using_option_alias(self, observation, option_alias):
-        action_num = self.agent_model.low_level_policy_aliases.index(option_alias)
+        action_num = self.agent_model.low_level_policy_aliases.index(
+            option_alias)
         q_values = self.agent_model.get_modified_q_values(observation)
         max_q_value = np.abs(np.max(q_values))
-        q_values = [np.exp(q_value/max_q_value) for q_value in q_values]
-        relevant = q_values[action_num]/np.sum(q_values)
+        q_values = [np.exp(q_value / max_q_value) for q_value in q_values]
+        relevant = q_values[action_num] / np.sum(q_values)
         return relevant


 class DQNAgentOverOptions(DQNAgent):
-    def __init__(self, model, low_level_policies, policy=None, test_policy=None, enable_double_dqn=True, enable_dueling_network=False,
-                 dueling_type='avg', *args, **kwargs):
-        super(DQNAgentOverOptions, self).__init__(model, policy, test_policy, enable_double_dqn, enable_dueling_network,
-                                                  dueling_type, *args, **kwargs)
+    def __init__(self,
+                 model,
+                 low_level_policies,
+                 policy=None,
+                 test_policy=None,
+                 enable_double_dqn=True,
+                 enable_dueling_network=False,
+                 dueling_type='avg',
+                 *args,
+                 **kwargs):
+        super(DQNAgentOverOptions, self).__init__(
+            model, policy, test_policy, enable_double_dqn,
+            enable_dueling_network, dueling_type, *args, **kwargs)
         self.low_level_policies = low_level_policies
         if low_level_policies is not None:
-            self.low_level_policy_aliases = list(self.low_level_policies.keys())
+            self.low_level_policy_aliases = list(
+                self.low_level_policies.keys())

     def __get_invalid_node_indices(self):
         """Returns a list of option indices that are invalid according to initiation conditions.
@@ -435,5 +466,3 @@ class DQNAgentOverOptions(DQNAgent):
             q_values[node_index] = -np.inf

         return q_values
-
-
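The hunk above shows how DQNAgentOverOptions rules out options whose initiation conditions fail: their Q-values are set to -inf before the greedy pick. A standalone sketch of that masking idea (the helper below is illustrative, not a method of the class):

    import numpy as np

    def best_valid_option(q_values, invalid_node_indices):
        """Argmax over Q-values after masking invalid options with -inf."""
        q_values = np.array(q_values, dtype=float)
        for node_index in invalid_node_indices:
            q_values[node_index] = -np.inf
        return int(np.argmax(q_values))

    print(best_valid_option([1.0, 5.0, 3.0], invalid_node_indices=[1]))  # prints 2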
-from.policy_base import PolicyBase
+from .policy_base import PolicyBase
 import numpy as np

@@ -23,10 +23,10 @@ class LearnerBase(PolicyBase):
         setattr(self, prop, kwargs.get(prop, default))

     def train(self,
               env,
               nb_steps=50000,
               visualize=False,
               nb_max_episode_steps=200):
         """Train the learning agent on the environment.

         Args:
 from .controller_base import ControllerBase


 class ManualPolicy(ControllerBase):
     """Manual policy execution using nodes and edges."""

-    def __init__(self, env, low_level_policies, transition_adj, start_node_alias):
+    def __init__(self, env, low_level_policies, transition_adj,
+                 start_node_alias):
         """Constructor for manual policy execution.

         Args:

@@ -13,7 +15,8 @@ class ManualPolicy(ControllerBase):
             start_node: starting node
         """
-        super(ManualPolicy, self).__init__(env, low_level_policies, start_node_alias)
+        super(ManualPolicy, self).__init__(env, low_level_policies,
+                                           start_node_alias)
         self.adj = transition_adj

     def _transition(self):

@@ -50,4 +53,4 @@ class ManualPolicy(ControllerBase):
         new_node = self._transition()
         if new_node is not None:
-            self.current_node = new_node
\ No newline at end of file
+            self.current_node = new_node
@@ -2,6 +2,7 @@ from .controller_base import ControllerBase
 import numpy as np
 import pickle

+
 class MCTSLearner(ControllerBase):
     """Monte Carlo Tree Search implementation using the UCB1 and
     progressive widening approach as explained in Paxton et al (2017).
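For reference, the UCB1 rule named in this docstring scores each child option by its value estimate plus an exploration bonus that shrinks with visitation count. The sketch below shows the standard form only; it is not lifted from the Q1/Q2 computation that appears later in this diff:

    import math

    def ucb1_score(q_value, parent_visits, child_visits, c=1.41):
        """Standard UCB1: exploitation term plus exploration bonus."""
        if child_visits == 0:
            return float("inf")  # always try unvisited children first
        return q_value + c * math.sqrt(math.log(parent_visits) / child_visits)

    # A rarely tried option can outrank a better-valued but heavily visited one.
    print(ucb1_score(0.4, parent_visits=100, child_visits=2))   # ~2.5
    print(ucb1_score(0.6, parent_visits=100, child_visits=50))  # ~1.0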
@@ -9,8 +10,8 @@ class MCTSLearner(ControllerBase):
     _ucb_vals = set()

-    def __init__(self, env, low_level_policies,
-                 start_node_alias, max_depth=10):
+    def __init__(self, env, low_level_policies, start_node_alias,
+                 max_depth=10):
         """Constructor for MCTSLearner.

         Args:

@@ -22,10 +23,12 @@ class MCTSLearner(ControllerBase):
             max_depth: max depth of the MCTS tree; default 10 levels
         """
-        super(MCTSLearner, self).__init__(env, low_level_policies, start_node_alias)
+        super(MCTSLearner, self).__init__(env, low_level_policies,
+                                          start_node_alias)
         self.controller_args_defaults = {
-            "predictor": None  #P(s, o) learner class; forward pass should return the entire value from state s and option o
+            "predictor":
+            None  #P(s, o) learner class; forward pass should return the entire value from state s and option o
         }
         self.max_depth = max_depth

         #: store current node alias
@@ -48,11 +51,17 @@ class MCTSLearner(ControllerBase):
         # populate root node
         root_node_num, root_node_info = self._create_node(self.curr_node_alias)
         self.nodes[root_node_num] = root_node_info
         self.adj[root_node_num] = set()  # no children

     def save_model(self, file_name="mcts.pickle"):
-        to_backup = {'N': self.N, 'M': self.M, 'TR': self.TR, 'nodes': self.nodes,
-                     'adj': self.adj, 'new_node_num': self.new_node_num}
+        to_backup = {
+            'N': self.N,
+            'M': self.M,
+            'TR': self.TR,
+            'nodes': self.nodes,
+            'adj': self.adj,
+            'new_node_num': self.new_node_num
+        }
         with open(file_name, 'wb') as handle:
             pickle.dump(to_backup, handle, protocol=pickle.HIGHEST_PROTOCOL)
@@ -96,9 +105,9 @@ class MCTSLearner(ControllerBase):
         """
         dis_observation = ''
         for item in observation[12:20]:
-            if type(item)==bool:
+            if type(item) == bool:
                 dis_observation += '1' if item is True else '0'
-            if type(item)==int and item in [0, 1]:
+            if type(item) == int and item in [0, 1]:
                 dis_observation += str(item)

         env = self.current_node.env

@@ -163,7 +172,8 @@ class MCTSLearner(ControllerBase):
         dis_observation = self._to_discrete(observation)
         if (dis_observation, option) not in self.TR:
             self.TR[(dis_observation, option)] = 0
-        return self.TR[(dis_observation, option)] / (1+self._get_visitation_count(observation, option))
+        return self.TR[(dis_observation, option)] / (
+            1 + self._get_visitation_count(observation, option))

     def _get_possible_options(self):
         """Returns a set of options that can be taken from the current node.

@@ -172,7 +182,7 @@ class MCTSLearner(ControllerBase):
         """
         all_options = set(self.low_level_policies.keys())

         # Filter nodes whose initiation condition are true
         filtered_options = set()
         for option_alias in all_options:
@@ -211,14 +221,15 @@ class MCTSLearner(ControllerBase):
         Returns Q values for next nodes
         """
         Q = {}
         Q1, Q2 = {}, {}  # debug
         dis_observation = self._to_discrete(observation)
         next_option_nums = self.adj[self.curr_node_num]
         for next_option_num in next_option_nums:
             next_option = self.nodes[next_option_num]["policy"]
-            Q1[next_option] = (self._get_q_star(observation, next_option)+200)/400
+            Q1[next_option] = (
+                self._get_q_star(observation, next_option) + 200) / 400
             Q[(dis_observation, next_option)] = \
                 Q1[next_option]
             Q2[next_option] = C * \

@@ -243,7 +254,7 @@ class MCTSLearner(ControllerBase):
         relevant_rewards = [value for key, value in self.TR.items() \
                             if key[0] == dis_observation]
         sum_rewards = np.sum(relevant_rewards)
-        return sum_rewards / (1+self._get_visitation_count(observation))
+        return sum_rewards / (1 + self._get_visitation_count(observation))

     def _select(self, observation, depth=0, visualize=False):
         """MCTS selection function. For representation, we only use
@@ -266,7 +277,8 @@ class MCTSLearner(ControllerBase):
         if is_terminal or max_depth_reached:
             # print('MCTS went %d nodes deep' % depth)
-            return self._value(observation), 0  # TODO: replace with final goal reward
+            return self._value(
+                observation), 0  # TODO: replace with final goal reward

         Ns = self._get_visitation_count(observation)
         Nchildren = len(self.adj[self.curr_node_num])

@@ -288,19 +300,20 @@ class MCTSLearner(ControllerBase):
             self.adj[self.curr_node_num].add(new_node_num)

         # Find o_star and do a transition, i.e. update curr_node
         # Simulate / lookup; first change next
-        next_observation, episode_R, o_star = self.do_transition(observation,
-                                                                 visualize=visualize)
+        next_observation, episode_R, o_star = self.do_transition(
+            observation, visualize=visualize)

         # Recursively select next node
-        remaining_v, all_ep_R = self._select(next_observation, depth+1, visualize=visualize)
+        remaining_v, all_ep_R = self._select(
+            next_observation, depth + 1, visualize=visualize)

         # Update values
         self.N[dis_observation] += 1
         self.M[(dis_observation, o_star)] += 1
         self.TR[(dis_observation, o_star)] += (episode_R + remaining_v)

-        return self._value(observation), all_ep_R+episode_R
+        return self._value(observation), all_ep_R + episode_R

     def traverse(self, observation, visualize=False):
         """Do a complete traversal from root to leaf. Assumes the
@@ -368,4 +381,4 @@ class MCTSLearner(ControllerBase):
         next_keys, next_values = list(Q1.keys()), list(Q1.values())
         o_star = next_keys[np.argmax(next_values)]
         print(Q1)
-        return o_star
\ No newline at end of file
+        return o_star
@@ -3,6 +3,7 @@ from .mcts_learner import MCTSLearner
 import tqdm
 import numpy as np

+
 class OnlineMCTSController(ControllerBase):
     """Online MCTS"""

@@ -13,12 +14,13 @@ class OnlineMCTSController(ControllerBase):
             env: env instance
             low_level_policies: low level policies dictionary
         """
-        super(OnlineMCTSController, self).__init__(env, low_level_policies, start_node_alias)
+        super(OnlineMCTSController, self).__init__(env, low_level_policies,
+                                                   start_node_alias)
         self.curr_node_alias = start_node_alias
         self.controller_args_defaults = {
             "predictor": None,
             "max_depth": 5,  # MCTS depth
             "nb_traversals": 30,  # MCTS traversals before decision
         }

     def set_current_node(self, node_alias):
@@ -48,12 +50,14 @@ class OnlineMCTSController(ControllerBase):
         env_before_mcts = orig_env.copy()
         self.change_low_level_references(env_before_mcts)
         print('Current Node: %s' % self.curr_node_alias)
-        mcts = MCTSLearner(self.env, self.low_level_policies, self.curr_node_alias)
+        mcts = MCTSLearner(self.env, self.low_level_policies,
+                           self.curr_node_alias)
         mcts.max_depth = self.max_depth
-        mcts.set_controller_args(predictor = self.predictor)
+        mcts.set_controller_args(predictor=self.predictor)
         # Do nb_traversals number of traversals, reset env to this point every time
         # print('Doing MCTS with params: max_depth = %d, nb_traversals = %d' % (self.max_depth, self.nb_traversals))
-        for num_epoch in range(self.nb_traversals):  # tqdm.tqdm(range(self.nb_traversals)):
+        for num_epoch in range(
+                self.nb_traversals):  # tqdm.tqdm(range(self.nb_traversals)):
             mcts.curr_node_num = 0
             env_begin_epoch = env_before_mcts.copy()
             self.change_low_level_references(env_begin_epoch)

@@ -63,6 +67,7 @@ class OnlineMCTSController(ControllerBase):
         # Find the nodes from the root node
         mcts.curr_node_num = 0
         print('%s' % mcts._to_discrete(self.env.get_features_tuple()))
-        node_after_transition = mcts.get_best_node(self.env.get_features_tuple(), use_ucb=False)
+        node_after_transition = mcts.get_best_node(
+            self.env.get_features_tuple(), use_ucb=False)
         print('MCTS suggested next option: %s' % node_after_transition)
-        self.set_current_node(node_after_transition)
\ No newline at end of file
+        self.set_current_node(node_after_transition)
 class PolicyBase:
     """Abstract policy base from which every policy backend is defined
-    and inherited."""
\ No newline at end of file
+    and inherited."""
@@ -11,7 +11,8 @@ class RLController(ControllerBase):
             env: env instance
             low_level_policies: low level policies dictionary
         """
-        super(RLController, self).__init__(env, low_level_policies, start_node_alias)
+        super(RLController, self).__init__(env, low_level_policies,
+                                           start_node_alias)
         self.low_level_policy_aliases = list(self.low_level_policies.keys())
         self.trained_policy = None