From 61a07254234fc12eeda5571b0c760c0b18c0e87c Mon Sep 17 00:00:00 2001
From: sacardoz <sacardoz@uwaterloo.ca>
Date: Mon, 27 Nov 2023 18:56:26 -0500
Subject: [PATCH] Occlusion Manager

---
 .../builders/occlusion_manager_builder.py     | 18 +++++++
 .../script/builders/simulation_builder.py     |  9 ++++
 .../simulation/occlusion/occlusion_manager.py | 54 +++++++++++++++++++
 nuplan/planning/simulation/simulation.py      |  8 ++-
 .../planning/simulation/simulation_setup.py   |  7 ++-
 5 files changed, 94 insertions(+), 2 deletions(-)
 create mode 100644 nuplan/planning/script/builders/occlusion_manager_builder.py
 create mode 100644 nuplan/planning/simulation/occlusion/occlusion_manager.py

diff --git a/nuplan/planning/script/builders/occlusion_manager_builder.py b/nuplan/planning/script/builders/occlusion_manager_builder.py
new file mode 100644
index 0000000..d496b49
--- /dev/null
+++ b/nuplan/planning/script/builders/occlusion_manager_builder.py
@@ -0,0 +1,18 @@
+
+from omegaconf import DictConfig
+
+from nuplan.planning.scenario_builder.abstract_scenario import AbstractScenario
+from nuplan.planning.simulation.occlusion.occlusion_manager import AbstractOcclusionManager
+from nuplan.planning.simulation.occlusion.range_occlusion_manager import RangeOcclusionManager
+
+def build_occlusion_manager(occlusion_cfg: DictConfig, scenario: AbstractScenario) -> AbstractOcclusionManager:
+    """
+    Instantiate occlusion_manager
+    :param occlusion_cfg: config of a occlusion_manager
+    :param scenario: scenario
+    :return occlusion_cfg
+    """
+    # Placeholder
+    occlusion_manager: AbstractOcclusionManager = RangeOcclusionManager(scenario)
+
+    return occlusion_manager
diff --git a/nuplan/planning/script/builders/simulation_builder.py b/nuplan/planning/script/builders/simulation_builder.py
index 1d35ab7..105fa67 100644
--- a/nuplan/planning/script/builders/simulation_builder.py
+++ b/nuplan/planning/script/builders/simulation_builder.py
@@ -9,6 +9,7 @@ from nuplan.common.utils.distributed_scenario_filter import DistributedMode, Dis
 from nuplan.planning.scenario_builder.nuplan_db.nuplan_scenario_builder import NuPlanScenarioBuilder
 from nuplan.planning.script.builders.metric_builder import build_metrics_engines
 from nuplan.planning.script.builders.observation_builder import build_observations
+from nuplan.planning.script.builders.occlusion_manager_builder import build_occlusion_manager
 from nuplan.planning.script.builders.planner_builder import build_planners
 from nuplan.planning.script.builders.utils.utils_type import is_target_type
 from nuplan.planning.simulation.callback.abstract_callback import AbstractCallback
@@ -16,6 +17,7 @@ from nuplan.planning.simulation.callback.metric_callback import MetricCallback
 from nuplan.planning.simulation.callback.multi_callback import MultiCallback
 from nuplan.planning.simulation.controller.abstract_controller import AbstractEgoController
 from nuplan.planning.simulation.observation.abstract_observation import AbstractObservation
+from nuplan.planning.simulation.occlusion.occlusion_manager import AbstractOcclusionManager
 from nuplan.planning.simulation.planner.abstract_planner import AbstractPlanner
 from nuplan.planning.simulation.runner.simulations_runner import SimulationRunner
 from nuplan.planning.simulation.simulation import Simulation
@@ -103,6 +105,12 @@ def build_simulations(
             # Perception
             observations: AbstractObservation = build_observations(cfg.observation, scenario=scenario)
 
+            # Occlusions
+            if 'occlusion' in cfg.keys() and cfg.occlusion:
+                occlusion_manager: AbstractOcclusionManager = build_occlusion_manager(cfg.occlusion, scenario=scenario)
+            else:
+                occlusion_manager = None
+
             # Metric Engine
             metric_engine = metric_engines_map.get(scenario.scenario_type, None)
             if metric_engine is not None:
@@ -120,6 +128,7 @@ def build_simulations(
                 time_controller=simulation_time_controller,
                 observations=observations,
                 ego_controller=ego_controller,
+                occlusion_manager=occlusion_manager,
                 scenario=scenario,
             )
 
diff --git a/nuplan/planning/simulation/occlusion/occlusion_manager.py b/nuplan/planning/simulation/occlusion/occlusion_manager.py
new file mode 100644
index 0000000..32ebb33
--- /dev/null
+++ b/nuplan/planning/simulation/occlusion/occlusion_manager.py
@@ -0,0 +1,54 @@
+from abc import ABCMeta, abstractmethod
+from collections import deque
+from typing import Tuple
+
+from nuplan.common.actor_state.ego_state import EgoState
+from nuplan.common.actor_state.tracked_objects import TrackedObjects
+
+from nuplan.planning.scenario_builder.abstract_scenario import AbstractScenario
+from nuplan.planning.simulation.history.simulation_history_buffer import SimulationHistoryBuffer
+from nuplan.planning.simulation.observation.observation_type import DetectionsTracks, Observation
+
+
+class AbstractOcclusionManager(metaclass=ABCMeta):
+    def __init__(
+        self,
+        scenario: AbstractScenario
+    ):
+        self._masks = {}
+        self.scenario = scenario
+
+    def reset(self) -> None:
+        self._masks = {}
+
+    def occlude_input(self, input_buffer: SimulationHistoryBuffer) -> SimulationHistoryBuffer:
+        ego_state_buffer = input_buffer.ego_state_buffer
+        observations_buffer = input_buffer.observation_buffer
+        sample_interval = input_buffer.sample_interval
+
+        for ego_state, observations in zip(ego_state_buffer, observations_buffer):
+            if ego_state.time_us not in self._masks:
+                self._masks[ego_state.time_us] = self._compute_mask(ego_state, observations)
+                
+        output_buffer = SimulationHistoryBuffer(ego_state_buffer, \
+                            deque([self._mask_input(ego_state.time_us, observations) for ego_state, observations in zip(ego_state_buffer, observations_buffer)]), \
+                                sample_interval)
+
+        return output_buffer
+    
+    @abstractmethod
+    def _compute_mask(self, ego_state: EgoState, observations: DetectionsTracks) -> set:
+        pass
+
+    def _mask_input(self, time_us: int, observations: DetectionsTracks) -> DetectionsTracks:
+        assert time_us in self._masks, "Attempted to mask non-cached timestep!"
+        assert isinstance(observations, DetectionsTracks), "Occlusions only support DetectionsTracks."
+
+        mask = self._masks[time_us]
+        tracks = observations.tracked_objects.tracked_objects
+
+        visible_tracks = [track for track in tracks if track.metadata.track_token in mask]
+
+        return DetectionsTracks(tracked_objects=TrackedObjects(visible_tracks))
+
+
diff --git a/nuplan/planning/simulation/simulation.py b/nuplan/planning/simulation/simulation.py
index 243d8d3..9e61f10 100644
--- a/nuplan/planning/simulation/simulation.py
+++ b/nuplan/planning/simulation/simulation.py
@@ -45,6 +45,7 @@ class Simulation:
         self._time_controller = simulation_setup.time_controller
         self._ego_controller = simulation_setup.ego_controller
         self._observations = simulation_setup.observations
+        self._occlusion_manager = simulation_setup.occlusion_manager
         self._scenario = simulation_setup.scenario
         self._callback = MultiCallback([]) if callback is None else callback
 
@@ -136,8 +137,13 @@ class Simulation:
 
         # Extract traffic light status data
         traffic_light_data = list(self._scenario.get_traffic_light_status_at_iteration(iteration.index))
+
+        history_input = self._history_buffer
+        if self._occlusion_manager is not None:
+            history_input = self._occlusion_manager.occlude_input(history_input)
+
         logger.debug(f"Executing {iteration.index}!")
-        return PlannerInput(iteration=iteration, history=self._history_buffer, traffic_light_data=traffic_light_data)
+        return PlannerInput(iteration=iteration, history=history_input, traffic_light_data=traffic_light_data)
 
     def propagate(self, trajectory: AbstractTrajectory) -> None:
         """
diff --git a/nuplan/planning/simulation/simulation_setup.py b/nuplan/planning/simulation/simulation_setup.py
index 34da9d3..944d06d 100644
--- a/nuplan/planning/simulation/simulation_setup.py
+++ b/nuplan/planning/simulation/simulation_setup.py
@@ -3,6 +3,7 @@ from dataclasses import dataclass
 from nuplan.planning.scenario_builder.abstract_scenario import AbstractScenario
 from nuplan.planning.simulation.controller.abstract_controller import AbstractEgoController
 from nuplan.planning.simulation.observation.abstract_observation import AbstractObservation
+from nuplan.planning.simulation.occlusion.occlusion_manager import AbstractOcclusionManager
 from nuplan.planning.simulation.planner.abstract_planner import AbstractPlanner
 from nuplan.planning.simulation.simulation_time_controller.abstract_simulation_time_controller import (
     AbstractSimulationTimeController,
@@ -16,6 +17,7 @@ class SimulationSetup:
     time_controller: AbstractSimulationTimeController
     observations: AbstractObservation
     ego_controller: AbstractEgoController
+    occlusion_manager: AbstractOcclusionManager
     scenario: AbstractScenario
 
     def __post_init__(self) -> None:
@@ -30,7 +32,7 @@ class SimulationSetup:
         assert isinstance(
             self.ego_controller, AbstractEgoController
         ), 'Error: ego_controller must inherit from AbstractEgoController!'
-
+        
     def reset(self) -> None:
         """
         Reset all simulation controllers
@@ -39,6 +41,9 @@ class SimulationSetup:
         self.ego_controller.reset()
         self.time_controller.reset()
 
+        if self.occlusion_manager:
+            self.occlusion_manager.reset()
+
 
 def validate_planner_setup(setup: SimulationSetup, planner: AbstractPlanner) -> None:
     """
-- 
GitLab