# ppo2_training.py: train, save, and test a PPO2 agent for the "changeright" maneuver.
from env.simple_intersection import SimpleIntersectionEnv
from env.simple_intersection.constants import *
from options.options_loader import OptionsGraph
from backends.baselines_learner import PPO2Agent

def main():
    """Train a PPO2 agent on the "changeright" maneuver, save it, then test it.

    Steps, in order:
      1. Seed NumPy's global random number generator (no fixed seed).
      2. Load the options graph from ``config.json``, select the
         "changeright" node and reset it. That node is also passed to the
         agent as its environment.
      3. Build a ``PPO2Agent``, train it, and save its weights to
         ``right_weights.h5f``.
      4. Run the agent on the same node for 50 test episodes.
    """
    # Imported explicitly so the script does not depend on `np` leaking in
    # through the wildcard import of env.simple_intersection.constants.
    import numpy as np

    # Initialize the numpy random number generator. No seed value is given,
    # so runs are not reproducible.
    np.random.seed()

    # Load the options graph and select the maneuver to be trained.
    options = OptionsGraph("config.json", SimpleIntersectionEnv)
    options.set_current_node("changeright")
    options.current_node.reset()
    maneuver = options.current_node

    # Use this code when you train a specific maneuver for the first time.
    agent = PPO2Agent(
        input_shape=(maneuver.get_reduced_feature_length(), ),
        nb_actions=2,
        env=maneuver,
        tensorboard=True)

    # Use this code to resume the training from the last step.
    # agent.load_weights("right_weights.h5f")

    # Train the NN models and save them.
    # NOTE(review): presumably at most 200 steps per episode and 50000
    # training steps in total -- confirm against PPO2Agent.fit.
    agent.fit(nb_max_episode_steps=200, nb_steps=50000)

    # Save the NN weights for reloading them in the future.
    agent.save_weights("right_weights.h5f")

    # Uncomment this after training to use trained model.
    # Comment this to use manually defined policy
    # agent.load_weights("right_weights.h5f")

    print("Testing model...")

    # Test trained maneuver
    # TODO: the graphical window is not closed before completing the test.
    agent.test_model(maneuver, nb_episodes=50)


if __name__ == "__main__":
    main()