Skip to content
Snippets Groups Projects
Commit fd11fe67 authored by Spencer Delcore's avatar Spencer Delcore
Browse files

added code to check gt with predictions, need to validate

parent 7b765c9a
No related branches found
No related tags found
No related merge requests found
...@@ -33,8 +33,8 @@ class GenericDataset: ...@@ -33,8 +33,8 @@ class GenericDataset:
self.name = name self.name = name
self.data_arr = data_arr self.data_arr = data_arr
if self.name == 'KITTI': if self.name == 'KITTI':
#with open(r"/root/lost/datasets/kitti_labels.pkl", "rb") as input_file: # TODO need to not hardcode
with open(r"/root/lost/Kitti2Coco/train/kitti_labels.pkl", "rb") as input_file: with open(r"tools/Kitti2Coco/kitti_train_labels.pkl", "rb") as input_file:
self.annots = pickle.load(input_file) self.annots = pickle.load(input_file)
print(len(self.data_arr)) print(len(self.data_arr))
...@@ -61,8 +61,6 @@ class GenericDataset: ...@@ -61,8 +61,6 @@ class GenericDataset:
return [img, self.data_arr[i]] return [img, self.data_arr[i]]
if self.annots != None: if self.annots != None:
if self.name == 'KITTI': if self.name == 'KITTI':
print(self.annots[im_name]['annotations'])
print(self.annots[im_name])
return [img, self.data_arr[i], self.annots[im_name]['annotations'], img.size, self.annots[im_name]] return [img, self.data_arr[i], self.annots[im_name]['annotations'], img.size, self.annots[im_name]]
return [img, self.data_arr[i], self.annots[im_name], img.size] return [img, self.data_arr[i], self.annots[im_name], img.size]
...@@ -74,8 +72,14 @@ class GenericDataset: ...@@ -74,8 +72,14 @@ class GenericDataset:
return None return None
if self.name == 'KITTI': if self.name == 'KITTI':
return None # TODO need to handle returning annotations annots = self.annots[im_name]['annotations']
gt_bbxs = []
gt_clss = []
for gt in annots:
gt_bbxs.append(gt['bbox'])
gt_clss.append(gt['category_id'])
return np.asarray(gt_bbxs), gt_clss
im = self.annots[im_name] im = self.annots[im_name]
# {"labels": ['bbox_x1','bbox_y1','bbox_x2','bbox_y2','class', 'test']} # {"labels": ['bbox_x1','bbox_y1','bbox_x2','bbox_y2','class', 'test']}
gt_bbxs = im[0:4] gt_bbxs = im[0:4]
......
...@@ -128,7 +128,7 @@ if __name__ == "__main__": ...@@ -128,7 +128,7 @@ if __name__ == "__main__":
if args.image_path is not None: if args.image_path is not None:
dataset = ImageDataset(args.image_path) dataset = ImageDataset(args.image_path)
elif args.dataset == "KITTI": elif args.dataset == "KITTI":
dataset = ImageFolderDataset("$KITTI_ROOT/training/image_2/") # TODO dont hard code dataset = ImageFolderDataset('KITTI', os.environ.get('$KITTI_ROOT','/root/kitti')+'/training/image_2/') # TODO dont hard code
else: else:
dataset = Dataset(args.dataset, args.set, args.no_hard) dataset = Dataset(args.dataset, args.set, args.no_hard)
......
OUTPUT_PATH=/root/kitti/lost_output OUTPUT_PATH=/root/lost/outputs/kitti
DINO_ARCH=vit_base DINO_ARCH=vit_base
LOST_FEATURES=k LOST_FEATURES=k
......
...@@ -43,7 +43,7 @@ def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name, ...@@ -43,7 +43,7 @@ def visualize_predictions(image, pred, seed, scales, dims, vis_folder, im_name,
(int(pred[2]), int(pred[3])), (int(pred[2]), int(pred[3])),
(0, 255, 0), 3, (0, 255, 0), 3,
) )
print("image.shape:",image.shape, "\npred_box: [x1,y1,x2,y2]", pred) #print("image.shape:",image.shape, "\npred_box: [x1,y1,x2,y2]", pred)
# Plot the seed # Plot the seed
if plot_seed: if plot_seed:
s_ = np.unravel_index(seed.cpu().numpy(), (w_featmap, h_featmap)) s_ = np.unravel_index(seed.cpu().numpy(), (w_featmap, h_featmap))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment