From fd11fe67773dfff685e2f3e772e6d1874ed0091c Mon Sep 17 00:00:00 2001 From: Spencer Delcore <sdelcore@uwaterloo.ca> Date: Fri, 17 Mar 2023 22:04:46 -0400 Subject: [PATCH] added code to check gt with predictions, need to validate --- datasets.py | 16 ++++++++++------ main_lost.py | 2 +- scripts/run-dataset.sh | 2 +- visualizations.py | 2 +- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/datasets.py b/datasets.py index 0d4e239..f04f86d 100755 --- a/datasets.py +++ b/datasets.py @@ -33,8 +33,8 @@ class GenericDataset: self.name = name self.data_arr = data_arr if self.name == 'KITTI': - #with open(r"/root/lost/datasets/kitti_labels.pkl", "rb") as input_file: - with open(r"/root/lost/Kitti2Coco/train/kitti_labels.pkl", "rb") as input_file: + # TODO need to not hardcode + with open(r"tools/Kitti2Coco/kitti_train_labels.pkl", "rb") as input_file: self.annots = pickle.load(input_file) print(len(self.data_arr)) @@ -61,8 +61,6 @@ class GenericDataset: return [img, self.data_arr[i]] if self.annots != None: if self.name == 'KITTI': - print(self.annots[im_name]['annotations']) - print(self.annots[im_name]) return [img, self.data_arr[i], self.annots[im_name]['annotations'], img.size, self.annots[im_name]] return [img, self.data_arr[i], self.annots[im_name], img.size] @@ -74,8 +72,14 @@ class GenericDataset: return None if self.name == 'KITTI': - return None # TODO need to handle returning annotations - + annots = self.annots[im_name]['annotations'] + gt_bbxs = [] + gt_clss = [] + for gt in annots: + gt_bbxs.append(gt['bbox']) + gt_clss.append(gt['category_id']) + return np.asarray(gt_bbxs), gt_clss + im = self.annots[im_name] # {"labels": ['bbox_x1','bbox_y1','bbox_x2','bbox_y2','class', 'test']} gt_bbxs = im[0:4] diff --git a/main_lost.py b/main_lost.py index 967395d..b7450b2 100755 --- a/main_lost.py +++ b/main_lost.py @@ -128,7 +128,7 @@ if __name__ == "__main__": if args.image_path is not None: dataset = ImageDataset(args.image_path) elif args.dataset == "KITTI": - dataset = ImageFolderDataset("$KITTI_ROOT/training/image_2/") # TODO dont hard code + dataset = ImageFolderDataset('KITTI', os.environ.get('$KITTI_ROOT','/root/kitti')+'/training/image_2/') # TODO dont hard code else: dataset = Dataset(args.dataset, args.set, args.no_hard) diff --git a/scripts/run-dataset.sh b/scripts/run-dataset.sh index 3566356..fa2a485 100755 --- a/scripts/run-dataset.sh +++ b/scripts/run-dataset.sh @@ -1,5 +1,5 @@ -OUTPUT_PATH=/root/kitti/lost_output +OUTPUT_PATH=/root/lost/outputs/kitti DINO_ARCH=vit_base LOST_FEATURES=k diff --git a/visualizations.py b/visualizations.py index a4ec699..0ef4009 100755 --- a/visualizations.py +++ b/visualizations.py @@ -43,7 +43,7 @@ def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name, (int(pred[2]), int(pred[3])), (0, 255, 0), 3, ) - print("image.shape:",image.shape, "\npred_box: [x1,y1,x2,y2]", pred) + #print("image.shape:",image.shape, "\npred_box: [x1,y1,x2,y2]", pred) # Plot the seed if plot_seed: s_ = np.unravel_index(seed.cpu().numpy(), (w_featmap, h_featmap)) -- GitLab