Commit 4a9327bd authored by Jae Young Lee's avatar Jae Young Lee

Improve and Bug-fix DQNLearner and environments.

- Added RestrictedEpsGreedyPolicy and RestrictedGreedyPolicy and use them as policy and test_policy in DQNLearner. Now, the agent never chooses the action corresponding to -inf Q-value if there is at least one action with finite Q-value (if not, it chooses any action randomly, which is necessary for compatibility with keras-rl --
 see the comments in select_action).

- Now, generate_scenario in SimpleIntersectionEnv generates veh_ahead_scenario even when randomize_special_scenario = 1.

- In EpisodicEnvBase, the terminal reward is now, by default, determined as the minimum of the applicable terminal rewards;

- Simplified the initiation_condition of EpisodicEnvBase;
parent d0b74b00
...@@ -8,7 +8,7 @@ from keras.callbacks import TensorBoard ...@@ -8,7 +8,7 @@ from keras.callbacks import TensorBoard
from rl.agents import DDPGAgent, DQNAgent from rl.agents import DDPGAgent, DQNAgent
from rl.memory import SequentialMemory from rl.memory import SequentialMemory
from rl.random import OrnsteinUhlenbeckProcess from rl.random import OrnsteinUhlenbeckProcess
from rl.policy import BoltzmannQPolicy, MaxBoltzmannQPolicy from rl.policy import GreedyQPolicy, EpsGreedyQPolicy, MaxBoltzmannQPolicy
from rl.callbacks import ModelIntervalCheckpoint from rl.callbacks import ModelIntervalCheckpoint
...@@ -229,6 +229,7 @@ class DQNLearner(LearnerBase): ...@@ -229,6 +229,7 @@ class DQNLearner(LearnerBase):
model=None, model=None,
policy=None, policy=None,
memory=None, memory=None,
test_policy=None,
**kwargs): **kwargs):
"""The constructor which sets the properties of the class. """The constructor which sets the properties of the class.
...@@ -236,8 +237,8 @@ class DQNLearner(LearnerBase): ...@@ -236,8 +237,8 @@ class DQNLearner(LearnerBase):
input_shape: Shape of observation space, e.g (10,); input_shape: Shape of observation space, e.g (10,);
nb_actions: number of values in action space; nb_actions: number of values in action space;
model: Keras Model of actor which takes observation as input and outputs actions. Uses default if not given model: Keras Model of actor which takes observation as input and outputs actions. Uses default if not given
policy: KerasRL Policy. Uses default SequentialMemory if not given policy: KerasRL Policy. Uses default MaxBoltzmannQPolicy if not given
memory: KerasRL Memory. Uses default BoltzmannQPolicy if not given memory: KerasRL Memory. Uses default SequentialMemory if not given
**kwargs: other optional key-value arguments with defaults defined in property_defaults **kwargs: other optional key-value arguments with defaults defined in property_defaults
""" """
super(DQNLearner, self).__init__(input_shape, nb_actions, **kwargs) super(DQNLearner, self).__init__(input_shape, nb_actions, **kwargs)
...@@ -255,12 +256,14 @@ class DQNLearner(LearnerBase): ...@@ -255,12 +256,14 @@ class DQNLearner(LearnerBase):
model = self.get_default_model() model = self.get_default_model()
if policy is None: if policy is None:
policy = self.get_default_policy() policy = self.get_default_policy()
if test_policy is None:
test_policy = self.get_default_test_policy()
if memory is None: if memory is None:
memory = self.get_default_memory() memory = self.get_default_memory()
self.low_level_policies = low_level_policies self.low_level_policies = low_level_policies
self.agent_model = self.create_agent(model, policy, memory) self.agent_model = self.create_agent(model, policy, memory, test_policy)
def get_default_model(self): def get_default_model(self):
"""Creates the default model. """Creates the default model.
...@@ -269,7 +272,6 @@ class DQNLearner(LearnerBase): ...@@ -269,7 +272,6 @@ class DQNLearner(LearnerBase):
""" """
model = Sequential() model = Sequential()
model.add(Flatten(input_shape=(1, ) + self.input_shape)) model.add(Flatten(input_shape=(1, ) + self.input_shape))
#model.add(Dense(64))
model.add(Dense(64)) model.add(Dense(64))
model.add(Activation('relu')) model.add(Activation('relu'))
model.add(Dense(64)) model.add(Dense(64))
...@@ -283,7 +285,10 @@ class DQNLearner(LearnerBase): ...@@ -283,7 +285,10 @@ class DQNLearner(LearnerBase):
return model return model
def get_default_policy(self): def get_default_policy(self):
return MaxBoltzmannQPolicy(eps=0.3) return RestrictedEpsGreedyQPolicy(0.3)
def get_default_test_policy(self):
return RestrictedGreedyQPolicy()
def get_default_memory(self): def get_default_memory(self):
"""Creates the default memory model. """Creates the default memory model.
...@@ -294,7 +299,7 @@ class DQNLearner(LearnerBase): ...@@ -294,7 +299,7 @@ class DQNLearner(LearnerBase):
limit=self.mem_size, window_length=self.mem_window_length) limit=self.mem_size, window_length=self.mem_window_length)
return memory return memory
def create_agent(self, model, policy, memory): def create_agent(self, model, policy, memory, test_policy):
"""Creates a KerasRL DDPGAgent with given components. """Creates a KerasRL DDPGAgent with given components.
Args: Args:
...@@ -413,6 +418,86 @@ class DQNLearner(LearnerBase): ...@@ -413,6 +418,86 @@ class DQNLearner(LearnerBase):
return relevant return relevant
class RestrictedEpsGreedyQPolicy(EpsGreedyQPolicy):
    """Epsilon-greedy policy restricted to actions with finite Q-values.

    With probability eps a random action is drawn uniformly from the
    actions whose Q-value is finite; otherwise the greedy action is
    taken. An action whose Q-value is -inf is never selected unless
    every Q-value is -inf.
    """

    def __init__(self, eps=.1):
        """Construct the policy.

        # Arguments
            eps (float): exploration rate in [0, 1] -- the probability
                of selecting a random valid action instead of the
                greedy one.
        """
        super(RestrictedEpsGreedyQPolicy, self).__init__(eps)

    def select_action(self, q_values):
        """Return the selected action index.

        # Arguments
            q_values (np.ndarray): 1-D array with one Q-value estimate
                per action; -np.inf marks a forbidden action.

        # Returns
            Selected action index (int).
        """
        assert q_values.ndim == 1
        nb_actions = q_values.shape[0]

        # Indices of the actions that are not forbidden (-inf).
        index = [i for i in range(nb_actions) if q_values[i] != -np.inf]

        # Every q_value is -np.inf. This sometimes inevitably happens within
        # the fit and test functions of keras-rl at the terminal stage, as
        # they force a call to forward, which calls this function. In that
        # case we choose an action uniformly at random.
        if len(index) < 1:
            # np.random.randint replaces the deprecated (and removed)
            # np.random.random_integers; its upper bound is exclusive.
            action = np.random.randint(0, nb_actions)
        elif np.random.uniform() <= self.eps:
            action = index[np.random.randint(0, len(index))]
        else:
            # Safe: -inf entries can never be the argmax when at least
            # one finite Q-value exists.
            action = np.argmax(q_values)

        return action
class RestrictedGreedyQPolicy(GreedyQPolicy):
    """Greedy policy restricted to actions with finite Q-values.

    Always selects the action with the highest Q-value. An action whose
    Q-value is -inf is never selected unless every Q-value is -inf, in
    which case an action is chosen uniformly at random.
    """

    def select_action(self, q_values):
        """Return the selected action index.

        # Arguments
            q_values (np.ndarray): 1-D array with one Q-value estimate
                per action; -np.inf marks a forbidden action.

        # Returns
            Selected action index (int).
        """
        assert q_values.ndim == 1

        # Every q_value is -np.inf. This sometimes inevitably happens within
        # the fit and test functions of keras-rl at the terminal stage, as
        # they force a call to forward, which calls this function. In that
        # case we choose an action uniformly at random.
        # (np.random.randint replaces the deprecated random_integers.)
        if np.all(q_values == -np.inf):
            action = np.random.randint(0, q_values.shape[0])
        else:
            # BUGFIX: the previous implementation took argmax over a
            # compacted list of only the finite Q-values, which returned
            # an index into that list rather than into q_values, so the
            # wrong action was chosen whenever a -inf entry preceded the
            # maximum. argmax over the full array is correct: -inf can
            # never win when at least one finite value exists.
            action = np.argmax(q_values)

        return action
class DQNAgentOverOptions(DQNAgent): class DQNAgentOverOptions(DQNAgent):
def __init__(self, def __init__(self,
model, model,
......
...@@ -25,7 +25,8 @@ class EpisodicEnvBase(GymCompliantEnvBase): ...@@ -25,7 +25,8 @@ class EpisodicEnvBase(GymCompliantEnvBase):
# three types possible ('min', 'max', or 'sum'); # three types possible ('min', 'max', or 'sum');
# See _reward_superposition below. # See _reward_superposition below.
terminal_reward_type = 'max' # TODO: consider the case, where every terminal reward is None. Make this class have a default terminal value (not None) and use it in this case.
terminal_reward_type = 'min'
#: If true, the maneuver terminates when the goal has been achieved. #: If true, the maneuver terminates when the goal has been achieved.
_terminate_in_goal = False _terminate_in_goal = False
...@@ -140,13 +141,11 @@ class EpisodicEnvBase(GymCompliantEnvBase): ...@@ -140,13 +141,11 @@ class EpisodicEnvBase(GymCompliantEnvBase):
def _reset_model_checker(self, AP): def _reset_model_checker(self, AP):
self.__mc_AP = int(AP)
if self._LTL_preconditions_enable: if self._LTL_preconditions_enable:
for LTL_precondition in self._LTL_preconditions: for LTL_precondition in self._LTL_preconditions:
LTL_precondition.reset_property() LTL_precondition.reset_property()
if LTL_precondition.enabled:
LTL_precondition.check_incremental(self.__mc_AP) self._incremental_model_checking(AP)
def _set_mc_AP(self, AP): def _set_mc_AP(self, AP):
self.__mc_AP = int(AP) self.__mc_AP = int(AP)
......
...@@ -271,7 +271,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase): ...@@ -271,7 +271,7 @@ class SimpleIntersectionEnv(RoadEnv, EpisodicEnvBase):
# stopped_car_scenario = bool(np.random.randint(0, 1)) TODO: this scenario may not work # stopped_car_scenario = bool(np.random.randint(0, 1)) TODO: this scenario may not work
n_others_stopped_in_stop_region = np.random.randint( n_others_stopped_in_stop_region = np.random.randint(
0, min(3, n_others - stopped_car_scenario)) 0, min(3, n_others - stopped_car_scenario))
veh_ahead_scenario = bool(np.random.randint(0, 1)) veh_ahead_scenario = bool(np.random.randint(0, 1)) or veh_ahead_scenario
if n_others_stopped_in_stop_region > min( if n_others_stopped_in_stop_region > min(
n_others - stopped_car_scenario, 3): n_others - stopped_car_scenario, 3):
......
...@@ -332,8 +332,7 @@ class ManeuverBase(EpisodicEnvBase): ...@@ -332,8 +332,7 @@ class ManeuverBase(EpisodicEnvBase):
Returns True if the condition is satisfied, and False otherwise. Returns True if the condition is satisfied, and False otherwise.
""" """
return not (self.env.termination_condition or self.violation_happened) and \ return not self.termination_condition and self.extra_initiation_condition
self.extra_initiation_condition
@property @property
def extra_initiation_condition(self): def extra_initiation_condition(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment