Skip to content
Snippets Groups Projects
Commit 21a8f561 authored by Spencer Delcore's avatar Spencer Delcore
Browse files

fix bugs in lost for coco finally

parent de70d078
No related branches found
No related tags found
No related merge requests found
...@@ -80,6 +80,9 @@ class Dataset: ...@@ -80,6 +80,9 @@ class Dataset:
with open(self.sel20k, "r") as f: with open(self.sel20k, "r") as f:
self.sel_20k = f.readlines() self.sel_20k = f.readlines()
self.sel_20k = [s.replace("\n", "") for s in self.sel_20k] self.sel_20k = [s.replace("\n", "") for s in self.sel_20k]
# im20k has the ids for each image in an array that matches sel_20k
# sel_20k has an array for images as train2014/COCO_#######
self.im20k = [str(int(s.split("_")[-1].split(".")[0])) for s in self.sel_20k]
else: else:
raise ValueError("Unknown dataset.") raise ValueError("Unknown dataset.")
...@@ -120,8 +123,7 @@ class Dataset: ...@@ -120,8 +123,7 @@ class Dataset:
if "VOC" in self.dataset_name: if "VOC" in self.dataset_name:
image = skimage.io.imread(f"datasets/VOC{self.year}/JPEGImages/{im_name}") image = skimage.io.imread(f"datasets/VOC{self.year}/JPEGImages/{im_name}")
elif "COCO" in self.dataset_name: elif "COCO" in self.dataset_name:
im_path = self.sel_20k[int(im_name)] image = skimage.io.imread(f"datasets/COCO/images/{im_name}")
image = skimage.io.imread(f"datasets/COCO/images/{im_path}")
else: else:
raise ValueError("Unkown dataset.") raise ValueError("Unkown dataset.")
return image return image
...@@ -133,8 +135,8 @@ class Dataset: ...@@ -133,8 +135,8 @@ class Dataset:
if "VOC" in self.dataset_name: if "VOC" in self.dataset_name:
im_name = inp["annotation"]["filename"] im_name = inp["annotation"]["filename"]
elif "COCO" in self.dataset_name: elif "COCO" in self.dataset_name:
im_name = str(inp[0]["image_id"]) im_id = self.im20k.index(str(inp[0]["image_id"]))
im_name = self.sel_20k[im_id]
return im_name return im_name
def extract_gt(self, targets, im_name): def extract_gt(self, targets, im_name):
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
from random import * from random import *
import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name, plot_seed=False, is_gt=False): def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name, plot_seed=False, is_gt=False):
...@@ -56,6 +56,7 @@ def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name, ...@@ -56,6 +56,7 @@ def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name,
) )
if im_name is not None: if im_name is not None:
pltname = f"{vis_folder}/LOST_{im_name}.png" pltname = f"{vis_folder}/LOST_{im_name}.png"
os.system('mkdir -p ' + os.path.dirname(pltname))
Image.fromarray(image).save(pltname) Image.fromarray(image).save(pltname)
#print(f"Predictions saved at {pltname}.") #print(f"Predictions saved at {pltname}.")
...@@ -139,5 +140,6 @@ def visualize_seed_expansion(image, pred, seed, pred_seed, scales, dims, vis_fol ...@@ -139,5 +140,6 @@ def visualize_seed_expansion(image, pred, seed, pred_seed, scales, dims, vis_fol
image[start_1:end_1, start_2:end_2, 2] = 41 image[start_1:end_1, start_2:end_2, 2] = 41
pltname = f"{vis_folder}/LOST_seed_expansion_{im_name}.png" pltname = f"{vis_folder}/LOST_seed_expansion_{im_name}.png"
os.system('mkdir -p' + os.path.dirname(pltname))
Image.fromarray(image).save(pltname) Image.fromarray(image).save(pltname)
print(f"Image saved at {pltname}.") print(f"Image saved at {pltname}.")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment