From 99564e184f0206d0d994ec01e71335ee7c76756f Mon Sep 17 00:00:00 2001
From: apsakinmukomi <oluwaseun.akinmukomi@africaprudential.com>
Date: Wed, 29 Mar 2023 17:54:46 +0100
Subject: [PATCH] - Added code for ap50 calculation

---
 main_lost.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/main_lost.py b/main_lost.py
index 5db52cf..e89e09a 100755
--- a/main_lost.py
+++ b/main_lost.py
@@ -334,6 +334,12 @@ if __name__ == "__main__":
             if args.no_evaluation:
                 continue
 
+        # Initialize variables for AP50 calculation
+        tp = 0
+        fp = 0
+        total_gt_boxes = len(gt_bbxs)
+        ap50 = 0
+
             # Compare prediction to GT boxes
         for pred in preds:
             if len(preds) == 0:
@@ -345,19 +351,31 @@ if __name__ == "__main__":
             ious = bbox_iou(torch.from_numpy(pred), torch.from_numpy(np.asarray(gt_bbxs)))
 
             # TODO: This calculates the corloc
-            # we need to calculate the AP50
             if torch.any(ious >= 0.50):
                 #corloc[im_id] = 1
                 corloc[im_id] = 0
             for i in ious:
                 if i >= 0.50:
-                    corloc[im_id] += 1 
+                    corloc[im_id] += 1
+
+            # Count true positives and false positives at IoU threshold of 0.5
+            if torch.any(ious >= 0.50):
+                tp += 1
+            else:
+                fp += 1
 
         cnt += len(gt_bbxs)
         
         if cnt % 50 == 0:
             pbar.set_description(f"Found {int(np.sum(corloc))}/{cnt}")
 
+        # Calculate precision and recall at IoU threshold of 0.5
+        precision = tp / (tp + fp)
+        recall = tp / total_gt_boxes
+
+        # Calculate AP50 as average precision at IoU threshold of 0.5
+        ap50 = precision * recall
+        print(f"AP50: {ap50:.2f}")
 
     # Save predicted bounding boxes
     if args.save_predictions:
-- 
GitLab