From 21a8f561cb7a43ee31ad043c7cdf5651dbce8c81 Mon Sep 17 00:00:00 2001
From: Spencer Delcore <sdelcore@uwaterloo.ca>
Date: Sat, 25 Mar 2023 14:13:51 -0400
Subject: [PATCH] fix bugs in lost for coco finally

---
 datasets.py       | 10 ++++++----
 visualizations.py |  4 +++-
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/datasets.py b/datasets.py
index bc72a73..95264c9 100755
--- a/datasets.py
+++ b/datasets.py
@@ -80,6 +80,9 @@ class Dataset:
             with open(self.sel20k, "r") as f:
                 self.sel_20k = f.readlines()
                 self.sel_20k = [s.replace("\n", "") for s in self.sel_20k]
+            # im20k has the ids for each image in an array that matches sel_20k
+            # sel_20k has an array for images as train2014/COCO_#######
+            self.im20k = [str(int(s.split("_")[-1].split(".")[0])) for s in self.sel_20k]
             
         else:
             raise ValueError("Unknown dataset.")
@@ -120,8 +123,7 @@ class Dataset:
         if "VOC" in self.dataset_name:
             image = skimage.io.imread(f"datasets/VOC{self.year}/JPEGImages/{im_name}")
         elif "COCO" in self.dataset_name:
-            im_path = self.sel_20k[int(im_name)]
-            image = skimage.io.imread(f"datasets/COCO/images/{im_path}")
+            image = skimage.io.imread(f"datasets/COCO/images/{im_name}")
         else:
             raise ValueError("Unkown dataset.")
         return image
@@ -133,8 +135,8 @@ class Dataset:
         if "VOC" in self.dataset_name:
             im_name = inp["annotation"]["filename"]
         elif "COCO" in self.dataset_name:
-            im_name = str(inp[0]["image_id"])
-
+            im_id = self.im20k.index(str(inp[0]["image_id"]))
+            im_name = self.sel_20k[im_id]
         return im_name
 
     def extract_gt(self, targets, im_name):
diff --git a/visualizations.py b/visualizations.py
index 0ef4009..8284f5e 100755
--- a/visualizations.py
+++ b/visualizations.py
@@ -19,7 +19,7 @@ import numpy as np
 import torch.nn as nn
 from PIL import Image
 from random import *
-
+import os
 import matplotlib.pyplot as plt
 
 def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name, plot_seed=False, is_gt=False):
@@ -56,6 +56,7 @@ def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name,
         )
     if im_name is not None:
         pltname = f"{vis_folder}/LOST_{im_name}.png"
+        os.system('mkdir -p ' + os.path.dirname(pltname))
         Image.fromarray(image).save(pltname)
     #print(f"Predictions saved at {pltname}.")
 
@@ -139,5 +140,6 @@ def visualize_seed_expansion(image, pred, seed, pred_seed, scales, dims, vis_fol
     image[start_1:end_1, start_2:end_2, 2] = 41
 
     pltname = f"{vis_folder}/LOST_seed_expansion_{im_name}.png"
+    os.system('mkdir -p' + os.path.dirname(pltname))
     Image.fromarray(image).save(pltname)
     print(f"Image saved at {pltname}.")
-- 
GitLab