From fc542bfed1442487909ec403c27fa40558c20988 Mon Sep 17 00:00:00 2001
From: sacardoz <sacardoz@uwaterloo.ca>
Date: Sat, 23 Dec 2023 21:14:28 -0500
Subject: [PATCH] idm works mashallah

---
 experiments/test_notebook.ipynb               | 138 +++++------
 .../observation/ml_planner_agents.py          | 215 +++++++++++++++---
 2 files changed, 255 insertions(+), 98 deletions(-)

diff --git a/experiments/test_notebook.ipynb b/experiments/test_notebook.ipynb
index abdac83..1b183ce 100644
--- a/experiments/test_notebook.ipynb
+++ b/experiments/test_notebook.ipynb
@@ -26,7 +26,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "/tmp/ipykernel_187055/4095267831.py:5: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
+      "/tmp/ipykernel_235422/4095267831.py:5: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display\n",
       "  from IPython.core.display import display, HTML\n"
      ]
     }
@@ -173,7 +173,7 @@
     "    f\"planner.pdm_hybrid_planner.checkpoint_path={hybrid_ckpt}\" ,\n",
     "    'observation.model_config=${model}',\n",
     "    f'observation.checkpoint_path={ckpt_dir}',\n",
-    "    f'observation.planner_type=idm',\n",
+    "    f'observation.planner_type=pdm_closed',\n",
     "    f'observation.pdm_hybrid_ckpt={hybrid_ckpt}',\n",
     "    f'observation.occlusions=true',\n",
     "    'worker=sequential',\n",
@@ -205,34 +205,34 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "2023-12-22 02:28:15,284 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/worker_pool_builder.py:19}  Building WorkerPool...\n",
-      "2023-12-22 02:28:15,286 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/utils/multithreading/worker_pool.py:101}  Worker: Sequential\n",
-      "2023-12-22 02:28:15,286 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/utils/multithreading/worker_pool.py:102}  Number of nodes: 1\n",
+      "2023-12-23 20:52:26,435 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/worker_pool_builder.py:19}  Building WorkerPool...\n",
+      "2023-12-23 20:52:26,437 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/utils/multithreading/worker_pool.py:101}  Worker: Sequential\n",
+      "2023-12-23 20:52:26,437 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/utils/multithreading/worker_pool.py:102}  Number of nodes: 1\n",
       "Number of CPUs per node: 1\n",
       "Number of GPUs per node: 0\n",
       "Number of threads across all nodes: 1\n",
-      "2023-12-22 02:28:15,286 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/worker_pool_builder.py:27}  Building WorkerPool...DONE!\n",
-      "2023-12-22 02:28:15,286 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/folder_builder.py:32}  Building experiment folders...\n",
-      "2023-12-22 02:28:15,286 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/folder_builder.py:35}  \n",
+      "2023-12-23 20:52:26,437 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/worker_pool_builder.py:27}  Building WorkerPool...DONE!\n",
+      "2023-12-23 20:52:26,437 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/folder_builder.py:32}  Building experiment folders...\n",
+      "2023-12-23 20:52:26,437 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/folder_builder.py:35}  \n",
       "\n",
-      "\tFolder where all results are stored: /media/sacardoz/Storage/nuplan/exp/exp/simulation/closed_loop_multiagent/2023.12.22.02.28.14\n",
+      "\tFolder where all results are stored: /media/sacardoz/Storage/nuplan/exp/exp/simulation/closed_loop_multiagent/2023.12.23.20.52.26\n",
       "\n",
-      "2023-12-22 02:28:15,288 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/folder_builder.py:70}  Building experiment folders...DONE!\n",
-      "2023-12-22 02:28:15,288 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_callback_builder.py:52}  Building AbstractCallback...\n",
-      "2023-12-22 02:28:15,288 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_callback_builder.py:68}  Building AbstractCallback: 0...DONE!\n",
-      "2023-12-22 02:28:15,288 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py:49}  Building simulations...\n",
-      "2023-12-22 02:28:15,288 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py:55}  Extracting scenarios...\n",
-      "2023-12-22 02:28:15,288 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/common/utils/distributed_scenario_filter.py:83}  Building Scenarios in mode DistributedMode.SINGLE_NODE\n",
-      "2023-12-22 02:28:15,289 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/scenario_building_builder.py:18}  Building AbstractScenarioBuilder...\n",
-      "2023-12-22 02:28:15,302 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/scenario_building_builder.py:21}  Building AbstractScenarioBuilder...DONE!\n",
-      "2023-12-22 02:28:15,302 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/scenario_filter_builder.py:35}  Building ScenarioFilter...\n",
-      "2023-12-22 02:28:15,303 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/scenario_filter_builder.py:44}  Building ScenarioFilter...DONE!\n",
-      "2023-12-22 02:28:15,322 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py:76}  Building metric engines...\n",
-      "2023-12-22 02:28:15,352 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py:78}  Building metric engines...DONE\n",
-      "2023-12-22 02:28:15,352 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py:82}  Building simulations from 1 scenarios...\n",
-      "2023-12-22 02:28:15,506 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/model_builder.py:18}  Building TorchModuleWrapper...\n",
-      "2023-12-22 02:28:15,547 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/model_builder.py:21}  Building TorchModuleWrapper...DONE!\n",
-      "2023-12-22 02:28:16,201 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py:142}  Building simulations...DONE!\n"
+      "2023-12-23 20:52:26,438 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/folder_builder.py:70}  Building experiment folders...DONE!\n",
+      "2023-12-23 20:52:26,439 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_callback_builder.py:52}  Building AbstractCallback...\n",
+      "2023-12-23 20:52:26,439 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_callback_builder.py:68}  Building AbstractCallback: 0...DONE!\n",
+      "2023-12-23 20:52:26,439 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py:49}  Building simulations...\n",
+      "2023-12-23 20:52:26,439 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py:55}  Extracting scenarios...\n",
+      "2023-12-23 20:52:26,439 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/common/utils/distributed_scenario_filter.py:83}  Building Scenarios in mode DistributedMode.SINGLE_NODE\n",
+      "2023-12-23 20:52:26,439 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/scenario_building_builder.py:18}  Building AbstractScenarioBuilder...\n",
+      "2023-12-23 20:52:26,452 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/scenario_building_builder.py:21}  Building AbstractScenarioBuilder...DONE!\n",
+      "2023-12-23 20:52:26,452 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/scenario_filter_builder.py:35}  Building ScenarioFilter...\n",
+      "2023-12-23 20:52:26,453 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/scenario_filter_builder.py:44}  Building ScenarioFilter...DONE!\n",
+      "2023-12-23 20:52:26,470 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py:76}  Building metric engines...\n",
+      "2023-12-23 20:52:26,498 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py:78}  Building metric engines...DONE\n",
+      "2023-12-23 20:52:26,498 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py:82}  Building simulations from 1 scenarios...\n",
+      "2023-12-23 20:52:26,633 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/model_builder.py:18}  Building TorchModuleWrapper...\n",
+      "2023-12-23 20:52:26,669 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/model_builder.py:21}  Building TorchModuleWrapper...DONE!\n",
+      "2023-12-23 20:52:27,238 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/script/builders/simulation_builder.py:142}  Building simulations...DONE!\n"
      ]
     }
    ],
@@ -286,28 +286,6 @@
      "output_type": "stream",
      "text": [
       "SimulationIteration(time_point=TimePoint(time_us=1623707846350127), index=0)\n",
-      "2023-12-22 02:28:27,237 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:28,208 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:29,094 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:30,074 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:31,080 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:32,985 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:34,668 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:35,448 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:36,250 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:38,638 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:43,540 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:44,367 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:51,440 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:52,331 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:53,206 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:54,967 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:57,304 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:58,116 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:58,881 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:28:59,765 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:29:00,661 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
-      "2023-12-22 02:29:01,327 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
       "SimulationIteration(time_point=TimePoint(time_us=1623707846450055), index=1)\n",
       "SimulationIteration(time_point=TimePoint(time_us=1623707846549980), index=2)\n",
       "SimulationIteration(time_point=TimePoint(time_us=1623707846649908), index=3)\n",
@@ -318,7 +296,6 @@
       "SimulationIteration(time_point=TimePoint(time_us=1623707847149533), index=8)\n",
       "SimulationIteration(time_point=TimePoint(time_us=1623707847249494), index=9)\n",
       "SimulationIteration(time_point=TimePoint(time_us=1623707847349489), index=10)\n",
-      "2023-12-22 02:29:31,281 WARNING {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/simulation/planner/idm_planner.py:153}  IDMPlanner could not find valid path to the target roadblock. Using longest route found instead\n",
       "SimulationIteration(time_point=TimePoint(time_us=1623707847449511), index=11)\n",
       "SimulationIteration(time_point=TimePoint(time_us=1623707847549545), index=12)\n",
       "SimulationIteration(time_point=TimePoint(time_us=1623707847649595), index=13)\n",
@@ -431,7 +408,32 @@
       "SimulationIteration(time_point=TimePoint(time_us=1623707858350135), index=120)\n",
       "SimulationIteration(time_point=TimePoint(time_us=1623707858450145), index=121)\n",
       "SimulationIteration(time_point=TimePoint(time_us=1623707858550149), index=122)\n",
-      "SimulationIteration(time_point=TimePoint(time_us=1623707858650146), index=123)\n"
+      "SimulationIteration(time_point=TimePoint(time_us=1623707858650146), index=123)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707858750137), index=124)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707858850126), index=125)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707858950113), index=126)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707859050096), index=127)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707859150076), index=128)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707859250064), index=129)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707859350064), index=130)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707859450064), index=131)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707859550063), index=132)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707859650055), index=133)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707859750047), index=134)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707859850041), index=135)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707859950028), index=136)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707860050018), index=137)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707860150002), index=138)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707860249979), index=139)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707860349963), index=140)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707860449960), index=141)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707860549964), index=142)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707860649965), index=143)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707860749961), index=144)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707860849946), index=145)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707860949927), index=146)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707861049906), index=147)\n",
+      "SimulationIteration(time_point=TimePoint(time_us=1623707861149865), index=148)\n"
      ]
     }
    ],
@@ -452,7 +454,7 @@
     "    # Perform step\n",
     "    planner_input = runner._simulation.get_planner_input()\n",
     "\n",
-    "    # Execute specific callback\n",
+    "    # Execute specific call+back\n",
     "    runner._simulation.callback.on_planner_start(runner.simulation.setup, runner.planner)\n",
     "\n",
     "    # Plan path based on all planner's inputs\n",
@@ -496,7 +498,7 @@
        "  (function() {\n",
        "    const xhr = new XMLHttpRequest()\n",
        "    xhr.responseType = 'blob';\n",
-       "    xhr.open('GET', \"http://localhost:5000/autoload.js?bokeh-autoload-element=1003&bokeh-absolute-url=http://localhost:5000&resources=none\", true);\n",
+       "    xhr.open('GET', \"http://localhost:5009/autoload.js?bokeh-autoload-element=1003&bokeh-absolute-url=http://localhost:5009&resources=none\", true);\n",
        "    xhr.onload = function (event) {\n",
        "      const script = document.createElement('script');\n",
        "      const src = URL.createObjectURL(event.target.response);\n",
@@ -510,7 +512,7 @@
      },
      "metadata": {
       "application/vnd.bokehjs_exec.v0+json": {
-       "server_id": "15b154fce756424d9ca7e3de9f9760be"
+       "server_id": "7d9bc83475124657aff7017705e9f4df"
       }
      },
      "output_type": "display_data"
@@ -519,26 +521,26 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "2023-12-22 01:52:05,312 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Error creating dataset. Could not read schema from 'run_sim_closed_loop/training_raster_experiment/train_default_raster/2023.11.14.22.55.23/hparams.yaml'. Is this a 'parquet' file?: Could not open Parquet input source 'run_sim_closed_loop/training_raster_experiment/train_default_raster/2023.11.14.22.55.23/hparams.yaml': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
-      "2023-12-22 01:52:05,313 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
-      "2023-12-22 01:52:05,313 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
-      "2023-12-22 01:52:05,314 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Error creating dataset. Could not read schema from 'pretrained_checkpoints/gc_pgp_checkpoint.ckpt'. Is this a 'parquet' file?: Could not open Parquet input source 'pretrained_checkpoints/gc_pgp_checkpoint.ckpt': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
-      "2023-12-22 01:52:05,314 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
-      "2023-12-22 01:52:05,314 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
-      "2023-12-22 01:52:05,315 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
-      "2023-12-22 01:52:05,315 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
-      "2023-12-22 01:52:05,318 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/simulation_tile.py:172}  Minimum frame time=0.017 s\n"
+      "2023-12-23 20:59:59,909 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Error creating dataset. Could not read schema from 'run_sim_closed_loop/training_raster_experiment/train_default_raster/2023.11.14.22.55.23/hparams.yaml'. Is this a 'parquet' file?: Could not open Parquet input source 'run_sim_closed_loop/training_raster_experiment/train_default_raster/2023.11.14.22.55.23/hparams.yaml': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
+      "2023-12-23 20:59:59,910 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
+      "2023-12-23 20:59:59,910 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
+      "2023-12-23 20:59:59,911 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Error creating dataset. Could not read schema from 'pretrained_checkpoints/gc_pgp_checkpoint.ckpt'. Is this a 'parquet' file?: Could not open Parquet input source 'pretrained_checkpoints/gc_pgp_checkpoint.ckpt': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
+      "2023-12-23 20:59:59,911 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
+      "2023-12-23 20:59:59,911 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
+      "2023-12-23 20:59:59,911 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
+      "2023-12-23 20:59:59,912 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/experiment_file_data.py:140}  Could not open Parquet input source '<Buffer>': Parquet magic bytes not found in footer. Either the file is corrupted or this is not a parquet file.\n",
+      "2023-12-23 20:59:59,914 INFO {/media/sacardoz/Storage/nuplan-devkit/nuplan/planning/nuboard/base/simulation_tile.py:172}  Minimum frame time=0.017 s\n"
      ]
     },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Rendering a scenario: 100%|██████████| 1/1 [00:00<00:00, 60.29it/s]\n",
+      "Rendering a scenario: 100%|██████████| 1/1 [00:00<00:00, 30.44it/s]\n",
       "WARNING:bokeh.core.validation.check:W-1000 (MISSING_RENDERERS): Plot has no renderers: Figure(id='1005', ...)\n",
-      "INFO:tornado.access:200 GET /autoload.js?bokeh-autoload-element=1003&bokeh-absolute-url=http://localhost:5000&resources=none (::1) 846.47ms\n",
+      "INFO:tornado.access:200 GET /autoload.js?bokeh-autoload-element=1003&bokeh-absolute-url=http://localhost:5009&resources=none (::1) 1063.26ms\n",
+      "INFO:tornado.access:101 GET /ws?id=cd802e51-de22-41c8-bc8e-99ef2dff2bff&origin=da2f890a-eb18-4637-9199-dd0f06169aef&swVersion=4&extensionId=&platform=electron&vscode-resource-base-authority=vscode-resource.vscode-cdn.net&parentOrigin=vscode-file%3A%2F%2Fvscode-app&purpose=notebookRenderer (::1) 0.62ms\n",
       "INFO:bokeh.server.views.ws:WebSocket connection opened\n",
-      "INFO:tornado.access:101 GET /ws?id=935f7e47-d5c0-410b-9cdd-a988d9f0ac32&origin=da2f890a-eb18-4637-9199-dd0f06169aef&swVersion=4&extensionId=&platform=electron&vscode-resource-base-authority=vscode-resource.vscode-cdn.net&parentOrigin=vscode-file%3A%2F%2Fvscode-app&purpose=notebookRenderer (::1) 0.64ms\n",
       "INFO:bokeh.server.views.ws:ServerConnection created\n"
      ]
     },
@@ -546,15 +548,15 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "2023-12-22 01:52:06,143 INFO {/home/sacardoz/miniconda3/envs/nuplan/lib/python3.9/site-packages/tornado/web.py:2344}  200 GET /autoload.js?bokeh-autoload-element=1003&bokeh-absolute-url=http://localhost:5000&resources=none (::1) 846.47ms\n",
-      "2023-12-22 01:52:06,157 INFO {/media/sacardoz/Storage/nuplan-devkit/tutorials/utils/tutorial_utils.py:267}  Done rendering!\n",
-      "2023-12-22 01:52:06,158 INFO {/home/sacardoz/miniconda3/envs/nuplan/lib/python3.9/site-packages/tornado/web.py:2344}  101 GET /ws?id=935f7e47-d5c0-410b-9cdd-a988d9f0ac32&origin=da2f890a-eb18-4637-9199-dd0f06169aef&swVersion=4&extensionId=&platform=electron&vscode-resource-base-authority=vscode-resource.vscode-cdn.net&parentOrigin=vscode-file%3A%2F%2Fvscode-app&purpose=notebookRenderer (::1) 0.64ms\n"
+      "2023-12-23 21:00:00,955 INFO {/home/sacardoz/miniconda3/envs/nuplan/lib/python3.9/site-packages/tornado/web.py:2344}  200 GET /autoload.js?bokeh-autoload-element=1003&bokeh-absolute-url=http://localhost:5009&resources=none (::1) 1063.26ms\n",
+      "2023-12-23 21:00:00,986 INFO {/media/sacardoz/Storage/nuplan-devkit/tutorials/utils/tutorial_utils.py:267}  Done rendering!\n",
+      "2023-12-23 21:00:00,987 INFO {/home/sacardoz/miniconda3/envs/nuplan/lib/python3.9/site-packages/tornado/web.py:2344}  101 GET /ws?id=cd802e51-de22-41c8-bc8e-99ef2dff2bff&origin=da2f890a-eb18-4637-9199-dd0f06169aef&swVersion=4&extensionId=&platform=electron&vscode-resource-base-authority=vscode-resource.vscode-cdn.net&parentOrigin=vscode-file%3A%2F%2Fvscode-app&purpose=notebookRenderer (::1) 0.62ms\n"
      ]
     }
    ],
    "source": [
     "from tutorials.utils.tutorial_utils import visualize_history\n",
-    "visualize_history(runner.simulation._history, runner.scenario, bokeh_port=5000)"
+    "visualize_history(runner.simulation._history, runner.scenario, bokeh_port=5009)"
    ]
   }
  ],
diff --git a/nuplan/planning/simulation/observation/ml_planner_agents.py b/nuplan/planning/simulation/observation/ml_planner_agents.py
index 5766585..b199913 100644
--- a/nuplan/planning/simulation/observation/ml_planner_agents.py
+++ b/nuplan/planning/simulation/observation/ml_planner_agents.py
@@ -1,6 +1,8 @@
 from collections import deque
 from copy import deepcopy
-from typing import Dict, List, Type
+from typing import Dict, List, Optional, Tuple, Type
+
+import numpy as np
 
 from nuplan.common.actor_state.agent import Agent
 from nuplan.common.actor_state.ego_state import EgoState
@@ -9,6 +11,8 @@ from nuplan.common.actor_state.state_representation import StateSE2, StateVector
 from nuplan.common.actor_state.tracked_objects import TrackedObject, TrackedObjects
 from nuplan.common.actor_state.tracked_objects_types import TrackedObjectType
 from nuplan.common.actor_state.vehicle_parameters import VehicleParameters
+from nuplan.common.maps.abstract_map import AbstractMap
+from nuplan.common.maps.abstract_map_objects import LaneGraphEdgeMapObject, RoadBlockGraphEdgeMapObject
 
 from nuplan.planning.scenario_builder.abstract_scenario import AbstractScenario
 from nuplan.planning.simulation.history.simulation_history_buffer import SimulationHistoryBuffer
@@ -31,6 +35,8 @@ from tuplan_garage.planning.simulation.planner.pdm_planner.pdm_hybrid_planner im
 
 from tuplan_garage.planning.simulation.planner.pdm_planner.proposal.batch_idm_policy import BatchIDMPolicy
 
+from nuplan.common.maps.maps_datatypes import SemanticMapLayer
+
 OPEN_LOOP_DETECTION_TYPES = [TrackedObjectType.PEDESTRIAN, TrackedObjectType.BICYCLE, \
                              TrackedObjectType.CZONE_SIGN, TrackedObjectType.BARRIER, \
                              TrackedObjectType.TRAFFIC_CONE, TrackedObjectType.GENERIC_OBJECT]
@@ -93,7 +99,6 @@ class MLPlannerAgents(AbstractObservation):
         self._trajectory_cache: Dict = {}
         self._inference_frequency: float = 0.2
         self._full_inference_distance: float = 30
-        self._agent_presence_threshold: float = 10
 
     def reset(self) -> None:
         """Inherited, see superclass."""
@@ -121,19 +126,21 @@ class MLPlannerAgents(AbstractObservation):
                 # Sets agent goal to be it's last known point in the simulation. This results in some strange driving behaviour
                 # if the agent disappears early in a scene.
                 goal = self._get_historical_agent_goal(agent, self.current_iteration)
-                #print(goal)
                 if goal:
-                    # Estimates ego states from agent state at simulation starts, stores metadata and creates planner for each agent
-                    self._agents[agent.metadata.track_token] = self._build_agent_record(agent, self._scenario.start_time)
 
-                    # Initialize planner.
-                    planner_init = PlannerInitialization(
-                            route_roadblock_ids=self._scenario.get_route_roadblock_ids(),
+                    route_plan = self._get_roadblock_path(agent, goal)
+
+                    if route_plan:
+                        self._agents[agent.metadata.track_token] = self._build_agent_record(agent, self._scenario.start_time)
+
+                        # Initialize planner.
+                        planner_init = PlannerInitialization(
+                            route_roadblock_ids=route_plan,
                             mission_goal=goal,
                             map_api=self._scenario.map_api,
                         )
-                    
-                    self._agents[agent.metadata.track_token]['planner'].initialize(planner_init)
+
+                        self._agents[agent.metadata.track_token]['planner'].initialize(planner_init)
 
         return self._agents
 
@@ -312,13 +319,19 @@ class MLPlannerAgents(AbstractObservation):
 
         self._agents[agent.metadata.track_token] = self._build_agent_record(agent, timepoint_record)
 
-        planner_init = PlannerInitialization(
-                route_roadblock_ids=self._scenario.get_route_roadblock_ids(),
+        route_plan = self._get_roadblock_path(agent, goal)
+
+        if route_plan:
+            self._agents[agent.metadata.track_token] = self._build_agent_record(agent, self._scenario.start_time)
+
+            # Initialize planner.
+            planner_init = PlannerInitialization(
+                route_roadblock_ids=route_plan,
                 mission_goal=goal,
                 map_api=self._scenario.map_api,
             )
-        
-        self._agents[agent.metadata.track_token]['planner'].initialize(planner_init)
+
+            self._agents[agent.metadata.track_token]['planner'].initialize(planner_init)
 
     def _build_agent_record(self, agent: Agent, timepoint_record: TimePoint):
 
@@ -340,7 +353,7 @@ class MLPlannerAgents(AbstractObservation):
                 'occlusion': WedgeOcclusionManager(self._scenario) if self._occlusions else None}
     
     def _get_historical_agent_goal(self, agent: Agent, iteration_index: int):
-        for frame in range(self._scenario.get_number_of_iterations()-1, iteration_index+self._agent_presence_threshold, -1):
+        for frame in range(self._scenario.get_number_of_iterations()-1, iteration_index, -1):
             last_scenario_frame = self._scenario.get_tracked_objects_at_iteration(frame)
             for track in last_scenario_frame.tracked_objects.tracked_objects:
                 if track.metadata.track_token == agent.metadata.track_token:
@@ -348,20 +361,162 @@ class MLPlannerAgents(AbstractObservation):
 
         return None
     
-    def _add_newly_detected_agents(self, next_iteration: SimulationIteration):
-        for agent in self._scenario.get_tracked_objects_at_iteration(next_iteration.index).tracked_objects.get_tracked_objects_of_type(TrackedObjectType.VEHICLE):
-            if agent.metadata.track_token not in self._agents:
-                goal = self._get_historical_agent_goal(agent, next_iteration.index)
+    def _get_roadblock_path(self, agent: Agent, goal: StateSE2, max_depth: int = 10):
 
-                if goal:
-                    # Estimates ego states from agent state at simulation starts, stores metadata and creates planner for each agent
-                    self._agents[agent.metadata.track_token] = self._build_agent_record(agent, next_iteration.time_point)
+        start_edge, _ = self._get_target_segment(agent.center, self._scenario.map_api)
+        end_edge, _ = self._get_target_segment(goal, self._scenario.map_api)
 
-                    # Initialize planner.
-                    planner_init = PlannerInitialization(
-                            route_roadblock_ids=self._scenario.get_route_roadblock_ids(),
-                            mission_goal=goal,
-                            map_api=self._scenario.map_api,
-                        )
-                    
-                    self._agents[agent.metadata.track_token]['planner'].initialize(planner_init)
+        if start_edge is None:
+            return None
+        
+        if end_edge is not None:
+            gs = BreadthFirstSearch(start_edge)
+            route_plan, path_found = gs.search(end_edge, max_depth)
+        else:
+            route_plan = [start_edge]
+
+        route_plan = self._extend_path(route_plan, max_depth)        
+        route_plan = [edge.get_roadblock_id() for edge in route_plan]
+        route_plan = list(dict.fromkeys(route_plan))
+        
+        if len(route_plan) == 1:
+            route_plan = route_plan + route_plan
+
+        return route_plan
+    
+    def _extend_path(self, route_plan: List[str], min_path_length: 10):
+        """
+        Extends a route plan to a given depth by continually going forward.
+        """
+        while len(route_plan) < min_path_length:
+            outgoing_edges = route_plan[-1].outgoing_edges
+
+            if not outgoing_edges:
+                break
+
+            curvatures = [abs(edge.baseline_path.get_curvature_at_arc_length(0.0)) for edge in outgoing_edges]
+            idx = np.argmin(curvatures)
+            route_plan.append(outgoing_edges[idx])
+
+        return route_plan
+
+    def _get_target_segment(
+        self, target_state: StateSE2, map_api: AbstractMap
+    ) -> Tuple[Optional[LaneGraphEdgeMapObject], Optional[float]]:
+        """
+        Gets the map object that the agent is on and the progress along the segment.
+        :param agent: The agent of interested.
+        :param map_api: An AbstractMap instance.
+        :return: GraphEdgeMapObject and progress along the segment. If no map object is found then None.
+        """
+        if map_api.is_in_layer(target_state, SemanticMapLayer.LANE):
+            layer = SemanticMapLayer.LANE
+        elif map_api.is_in_layer(target_state, SemanticMapLayer.INTERSECTION):
+            layer = SemanticMapLayer.LANE_CONNECTOR
+        else:
+            return None, None
+
+        segments: List[LaneGraphEdgeMapObject] = map_api.get_all_map_objects(target_state, layer)
+        if not segments:
+            return None, None
+
+        # Get segment with the closest heading to the agent
+        heading_diff = [
+            segment.baseline_path.get_nearest_pose_from_position(target_state).heading - target_state.heading
+            for segment in segments
+        ]
+        closest_segment = segments[np.argmin(np.abs(heading_diff))]
+
+        progress = closest_segment.baseline_path.get_nearest_arc_length_from_position(target_state)
+        return closest_segment, progress
+
+
+class BreadthFirstSearch:
+ 
+
+    def __init__(self, start_edge: LaneGraphEdgeMapObject):
+
+        self._queue = deque([start_edge, None])
+        self._parent: Dict[str, Optional[LaneGraphEdgeMapObject]] = dict()
+        self._visited = set()
+
+    def search(
+        self, target_edge: LaneGraphEdgeMapObject, max_depth: int
+    ) -> Tuple[List[LaneGraphEdgeMapObject], bool]:
+
+        start_edge = self._queue[0]
+
+        # Initial search states
+        path_found: bool = False
+        end_edge: LaneGraphEdgeMapObject = start_edge
+        end_depth: int = 1
+        depth: int = 1
+
+        self._parent[start_edge.id + f"_{depth}"] = None
+
+        while self._queue:
+            current_edge = self._queue.popleft()
+            if current_edge is not None:
+                self._visited.add(current_edge.id)
+
+            # Early exit condition
+            if self._check_end_condition(depth, max_depth):
+                break
+
+            # Depth tracking
+            if current_edge is None:
+                depth += 1
+                self._queue.append(None)
+                if self._queue[0] is None:
+                    break
+                continue
+
+            # Goal condition
+            if self._check_goal_condition(current_edge, target_edge):
+                end_edge = current_edge
+                end_depth = depth
+                path_found = True
+                break
+
+            # Populate queue
+            for next_edge in current_edge.outgoing_edges:
+                if next_edge.id not in self._visited:
+                    self._queue.append(next_edge)
+                    self._parent[next_edge.id + f"_{depth + 1}"] = current_edge
+                    end_edge = next_edge
+                    end_depth = depth + 1
+
+        return self._construct_path(end_edge, end_depth), path_found
+
+    @staticmethod
+    def _check_end_condition(depth: int, target_depth: int) -> bool:
+        """
+        Check if the search should end regardless if the goal condition is met.
+        :param depth: The current depth to check.
+        :param target_depth: The target depth to check against.
+        :return: True if:
+            - The current depth exceeds the target depth.
+        """
+        return depth > target_depth
+
+    @staticmethod
+    def _check_goal_condition(
+        current_edge: LaneGraphEdgeMapObject,
+        target_edge: LaneGraphEdgeMapObject,
+    ) -> bool:
+        return current_edge.id == target_edge.id
+
+    def _construct_path(self, end_edge: LaneGraphEdgeMapObject, depth: int) -> List[LaneGraphEdgeMapObject]:
+        """
+        :param end_edge: The end edge to start back propagating back to the start edge.
+        :param depth: The depth of the target edge.
+        :return: The constructed path as a list of LaneGraphEdgeMapObject
+        """
+        path = [end_edge]
+        while self._parent[end_edge.id + f"_{depth}"] is not None:
+            path.append(self._parent[end_edge.id + f"_{depth}"])
+            end_edge = self._parent[end_edge.id + f"_{depth}"]
+            depth -= 1
+        path.reverse()
+
+        return path
-- 
GitLab