Skip to content
Snippets Groups Projects
Commit 21a8f561 authored by Spencer Delcore's avatar Spencer Delcore
Browse files

fix bugs in lost for coco finally

parent de70d078
No related branches found
No related tags found
No related merge requests found
......@@ -80,6 +80,9 @@ class Dataset:
with open(self.sel20k, "r") as f:
self.sel_20k = f.readlines()
self.sel_20k = [s.replace("\n", "") for s in self.sel_20k]
# im20k has the ids for each image in an array that matches sel_20k
# sel_20k has an array for images as train2014/COCO_#######
self.im20k = [str(int(s.split("_")[-1].split(".")[0])) for s in self.sel_20k]
else:
raise ValueError("Unknown dataset.")
......@@ -120,8 +123,7 @@ class Dataset:
if "VOC" in self.dataset_name:
image = skimage.io.imread(f"datasets/VOC{self.year}/JPEGImages/{im_name}")
elif "COCO" in self.dataset_name:
im_path = self.sel_20k[int(im_name)]
image = skimage.io.imread(f"datasets/COCO/images/{im_path}")
image = skimage.io.imread(f"datasets/COCO/images/{im_name}")
else:
raise ValueError("Unkown dataset.")
return image
......@@ -133,8 +135,8 @@ class Dataset:
if "VOC" in self.dataset_name:
im_name = inp["annotation"]["filename"]
elif "COCO" in self.dataset_name:
im_name = str(inp[0]["image_id"])
im_id = self.im20k.index(str(inp[0]["image_id"]))
im_name = self.sel_20k[im_id]
return im_name
def extract_gt(self, targets, im_name):
......
......@@ -19,7 +19,7 @@ import numpy as np
import torch.nn as nn
from PIL import Image
from random import *
import os
import matplotlib.pyplot as plt
def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name, plot_seed=False, is_gt=False):
......@@ -56,6 +56,7 @@ def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name,
)
if im_name is not None:
pltname = f"{vis_folder}/LOST_{im_name}.png"
os.system('mkdir -p ' + os.path.dirname(pltname))
Image.fromarray(image).save(pltname)
#print(f"Predictions saved at {pltname}.")
......@@ -139,5 +140,6 @@ def visualize_seed_expansion(image, pred, seed, pred_seed, scales, dims, vis_fol
image[start_1:end_1, start_2:end_2, 2] = 41
pltname = f"{vis_folder}/LOST_seed_expansion_{im_name}.png"
os.system('mkdir -p' + os.path.dirname(pltname))
Image.fromarray(image).save(pltname)
print(f"Image saved at {pltname}.")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment