From 21a8f561cb7a43ee31ad043c7cdf5651dbce8c81 Mon Sep 17 00:00:00 2001 From: Spencer Delcore <sdelcore@uwaterloo.ca> Date: Sat, 25 Mar 2023 14:13:51 -0400 Subject: [PATCH] fix bugs in lost for coco finally --- datasets.py | 10 ++++++---- visualizations.py | 4 +++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/datasets.py b/datasets.py index bc72a73..95264c9 100755 --- a/datasets.py +++ b/datasets.py @@ -80,6 +80,9 @@ class Dataset: with open(self.sel20k, "r") as f: self.sel_20k = f.readlines() self.sel_20k = [s.replace("\n", "") for s in self.sel_20k] + # im20k has the ids for each image in an array that matches sel_20k + # sel_20k has an array for images as train2014/COCO_####### + self.im20k = [str(int(s.split("_")[-1].split(".")[0])) for s in self.sel_20k] else: raise ValueError("Unknown dataset.") @@ -120,8 +123,7 @@ class Dataset: if "VOC" in self.dataset_name: image = skimage.io.imread(f"datasets/VOC{self.year}/JPEGImages/{im_name}") elif "COCO" in self.dataset_name: - im_path = self.sel_20k[int(im_name)] - image = skimage.io.imread(f"datasets/COCO/images/{im_path}") + image = skimage.io.imread(f"datasets/COCO/images/{im_name}") else: raise ValueError("Unkown dataset.") return image @@ -133,8 +135,8 @@ class Dataset: if "VOC" in self.dataset_name: im_name = inp["annotation"]["filename"] elif "COCO" in self.dataset_name: - im_name = str(inp[0]["image_id"]) - + im_id = self.im20k.index(str(inp[0]["image_id"])) + im_name = self.sel_20k[im_id] return im_name def extract_gt(self, targets, im_name): diff --git a/visualizations.py b/visualizations.py index 0ef4009..8284f5e 100755 --- a/visualizations.py +++ b/visualizations.py @@ -19,7 +19,7 @@ import numpy as np import torch.nn as nn from PIL import Image from random import * - +import os import matplotlib.pyplot as plt def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name, plot_seed=False, is_gt=False): @@ -56,6 +56,7 @@ def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name, ) if im_name is not None: pltname = f"{vis_folder}/LOST_{im_name}.png" + os.system('mkdir -p ' + os.path.dirname(pltname)) Image.fromarray(image).save(pltname) #print(f"Predictions saved at {pltname}.") @@ -139,5 +140,6 @@ def visualize_seed_expansion(image, pred, seed, pred_seed, scales, dims, vis_fol image[start_1:end_1, start_2:end_2, 2] = 41 pltname = f"{vis_folder}/LOST_seed_expansion_{im_name}.png" + os.system('mkdir -p' + os.path.dirname(pltname)) Image.fromarray(image).save(pltname) print(f"Image saved at {pltname}.") -- GitLab