ppo2_training.py 1.33 KB
Newer Older
Aravind Bk's avatar
Aravind Bk committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
from env.simple_intersection import SimpleIntersectionEnv
from env.simple_intersection.constants import *
from options.options_loader import OptionsGraph
from backends.baselines_learner import PPO2Agent

if __name__ == "__main__":
    # Import numpy explicitly instead of relying on `np` leaking in through
    # the wildcard import from env.simple_intersection.constants.
    import numpy as np

    # Single source of truth for the weights file, used when saving after
    # training and (optionally) when resuming or reloading for testing.
    WEIGHTS_FILE = "right_weights.h5f"

    # Re-seed numpy's global RNG. Calling seed() without an argument draws
    # fresh entropy, so each run differs.
    np.random.seed()

    # Load the options graph and select the "changeright" maneuver.
    options = OptionsGraph("config.json", SimpleIntersectionEnv)
    options.set_current_node("changeright")
    options.current_node.reset()

    # Use this code when you train a specific maneuver for the first time.
    agent = PPO2Agent(
        input_shape=(options.current_node.get_reduced_feature_length(),),
        nb_actions=2,
        env=options.current_node,
        tensorboard=True,
    )

    # Use this code to resume the training from the last step.
    # agent.load_weights(WEIGHTS_FILE)

    # Train the NN models and save the weights so they can be reloaded later.
    agent.fit(nb_max_episode_steps=200, nb_steps=50000)
    agent.save_weights(WEIGHTS_FILE)

    # Uncomment this after training to use the trained model.
    # Keep it commented out to use the manually defined policy.
    # agent.load_weights(WEIGHTS_FILE)

    print("Testing model...")

    # Test the trained maneuver.
    # TODO: the graphical window is not closed before completing the test.
    agent.test_model(options.current_node, nb_episodes=50)