Commit 32f95d60 authored by xavier's avatar xavier

clean up

parent e0c6493c
......@@ -4,7 +4,7 @@ import pandas as pd
class rlalgorithm:
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.01, lam=0):
def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.01, lam=0.9):
self.actions = actions
self.lam = lam
self.lr = learning_rate
......@@ -51,7 +51,11 @@ class rlalgorithm:
if s_ != 'terminal':
self.expectation *= self.gamma * self.lam
else:
self.expectation = pd.DataFrame(columns=self.actions, dtype=np.float64) # clear it once episode is done.
# clear it once episode is done.
self.expectation = pd.DataFrame(0,
columns=self.q_table.columns,
index=self.q_table.index)
return s_, None
return s_, a_
'''States are dynamically added to the Q(S,A) table as they are encountered'''
......
......@@ -3,6 +3,7 @@ from RL_brainsample_PI import rlalgorithm as rlalg1
from RL_reinforce import rlalgorithm as rlalg_policy_grad
from RL_expected_sarsa import rlalgorithm as rlalg_expected_sarsa
from RL_double_q_learning import rlalgorithm as rlalg_double_q
from RL_sarsa_lambda import rlalgorithm as rlalg_sarsa_lambda
import numpy as np
import sys
......@@ -122,12 +123,22 @@ if __name__ == "__main__":
pits=[]
# Task 3
wall_shape=np.array([[6,3],[6,3],[6,2],[5,2],[4,2],[3,2],[3,3],
[3,4],[3,5],[3,6],[4,6],[5,6],[5,7],[7,3]])
pits=np.array([[1,3],[0,5], [7,7], [8,5]])
#wall_shape=np.array([[6,3],[6,3],[6,2],[5,2],[4,2],[3,2],[3,3],
# [3,4],[3,5],[3,6],[4,6],[5,6],[5,7],[7,3]])
#pits=np.array([[1,3],[0,5], [7,7], [8,5]])
experiments = []#[(env1,RL1, data1)]
########## Sarsa Lambda #########################
start_time = time.time()
env5 = Maze(agentXY,goalXY,wall_shape,pits)
RL5 = rlalg_sarsa_lambda(actions=list(range(env5.n_actions)),lam=0.9)
data5={}
env5.after(1, update(env5, RL5, data5, 2000))
env5.mainloop()
experiments.append((env5,RL5, data5))
print("Elapsed time for Sarsa Lambda is")
elapsed_time = time.time() - start_time
print(elapsed_time)
########## Double Q-learning ######################
start_time = time.time()
......@@ -173,7 +184,7 @@ if __name__ == "__main__":
if(do_plot_rewards):
#Simple plot of return for each episode and algorithm, you can make more informative plots
plot_rewards(experiments)
plot_rewards([(env2,RL2, data2)])
#plot_rewards([(env2,RL2, data2)])
plot_rewards([(env3,RL3, data3)])#expected sarsa
plot_rewards([(env4,RL4, data4)])#double q learning
plot_rewards([(env5,RL5, data5)])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment