diff --git a/main_lost.py b/main_lost.py
index 5db52cf6473f053710559ec2d22eafd8c785864d..cdd61624a88485575a89471428ad0c6f96a590e2 100755
--- a/main_lost.py
+++ b/main_lost.py
@@ -1,377 +1,395 @@
-# Copyright 2021 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import argparse
-import random
-import pickle
-
-import torch
-import torch.nn as nn
-import numpy as np
-
-from tqdm import tqdm
-from PIL import Image
-
-from networks import get_model
-from datasets import ImageDataset, Dataset, bbox_iou
-from visualizations import visualize_fms, visualize_predictions, visualize_seed_expansion
-from object_discovery import lost, detect_box, dino_seg
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser("Unsupervised object discovery with LOST.")
-    parser.add_argument(
-        "--arch",
-        default="vit_small",
-        type=str,
-        choices=[
-            "vit_tiny",
-            "vit_small",
-            "vit_base",
-            "resnet50",
-            "vgg16_imagenet",
-            "resnet50_imagenet",
-        ],
-        help="Model architecture.",
-    )
-    parser.add_argument(
-        "--patch_size", default=16, type=int, help="Patch resolution of the model."
-    )
-
-    # Use a dataset
-    parser.add_argument(
-        "--dataset",
-        default="VOC07",
-        type=str,
-        choices=[None, "VOC07", "VOC12", "COCO20k"],
-        help="Dataset name.",
-    )
-    parser.add_argument(
-        "--set",
-        default="train",
-        type=str,
-        choices=["val", "train", "trainval", "test"],
-        help="Path of the image to load.",
-    )
-    # Or use a single image
-    parser.add_argument(
-        "--image_path",
-        type=str,
-        default=None,
-        help="If want to apply only on one image, give file path.",
-    )
-
-    # Folder used to output visualizations and 
-    parser.add_argument(
-        "--output_dir", type=str, default="outputs", help="Output directory to store predictions and visualizations."
-    )
-
-    # Evaluation setup
-    parser.add_argument("--no_hard", action="store_true", help="Only used in the case of the VOC_all setup (see the paper).")
-    parser.add_argument("--no_evaluation", action="store_true", help="Compute the evaluation.")
-    parser.add_argument("--save_predictions", default=True, type=bool, help="Save predicted bouding boxes.")
-    parser.add_argument("--num_init_seeds", default=1, type=int, help="Number of initial seeds to expand from.")
-
-    # Visualization
-    parser.add_argument(
-        "--visualize",
-        type=str,
-        choices=["fms", "seed_expansion", "pred", None],
-        default=None,
-        help="Select the different type of visualizations.",
-    )
-
-    # For ResNet dilation
-    parser.add_argument("--resnet_dilate", type=int, default=2, help="Dilation level of the resnet model.")
-
-    # LOST parameters
-    parser.add_argument(
-        "--which_features",
-        type=str,
-        default="k",
-        choices=["k", "q", "v"],
-        help="Which features to use",
-    )
-    parser.add_argument(
-        "--k_patches",
-        type=int,
-        default=100,
-        help="Number of patches with the lowest degree considered."
-    )
-
-    # Use dino-seg proposed method
-    parser.add_argument("--dinoseg", action="store_true", help="Apply DINO-seg baseline.")
-    parser.add_argument("--dinoseg_head", type=int, default=4)
-
-    args = parser.parse_args()
-
-    if args.image_path is not None:
-        args.save_predictions = False
-        args.no_evaluation = True
-        args.dataset = None
-
-    # -------------------------------------------------------------------------------------------------------
-    # Dataset
-
-    # If an image_path is given, apply the method only to the image
-    if args.image_path is not None:
-        dataset = ImageDataset(args.image_path)
-    else:
-        dataset = Dataset(args.dataset, args.set, args.no_hard)
-
-    # -------------------------------------------------------------------------------------------------------
-    # Model
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    print("Running on device:", device)
-    model = get_model(args.arch, args.patch_size, args.resnet_dilate, device)
-
-    # -------------------------------------------------------------------------------------------------------
-    # Directories
-    if args.image_path is None:
-        args.output_dir = os.path.join(args.output_dir, dataset.name)
-    os.makedirs(args.output_dir, exist_ok=True)
-
-    # Naming
-    if args.dinoseg:
-        # Experiment with the baseline DINO-seg
-        if "vit" not in args.arch:
-            raise ValueError("DINO-seg can only be applied to tranformer networks.")
-        exp_name = f"{args.arch}-{args.patch_size}_dinoseg-head{args.dinoseg_head}"
-    else:
-        # Experiment with LOST
-        exp_name = f"LOST-{args.arch}"
-        if "resnet" in args.arch:
-            exp_name += f"dilate{args.resnet_dilate}"
-        elif "vit" in args.arch:
-            exp_name += f"{args.patch_size}_{args.which_features}"
-
-    print(f"Running LOST on the dataset {dataset.name} (exp: {exp_name})")
-
-    # Visualization 
-    if args.visualize:
-        vis_folder = f"{args.output_dir}/visualizations/{exp_name}"
-        os.makedirs(vis_folder, exist_ok=True)
-    
-    # -------------------------------------------------------------------------------------------------------
-    # Loop over images
-    preds_dict = {}
-    gt_dict = {}
-    cnt = 0
-    corloc = np.zeros(len(dataset.dataloader))
-    
-    pbar = tqdm(dataset.dataloader)
-    for im_id, inp in enumerate(pbar):
-        torch.cuda.empty_cache()
-        # ------------ IMAGE PROCESSING -------------------------------------------
-        img = inp[0]
-        init_image_size = img.shape
-
-        # Get the name of the image
-        im_name = dataset.get_image_name(inp[1])
-
-        # Pass in case of no gt boxes in the image
-        if im_name is None:
-            continue
-
-        # Padding the image with zeros to fit multiple of patch-size
-        size_im = (
-            img.shape[0],
-            int(np.ceil(img.shape[1] / args.patch_size) * args.patch_size),
-            int(np.ceil(img.shape[2] / args.patch_size) * args.patch_size),
-        )
-        paded = torch.zeros(size_im)
-        paded[:, : img.shape[1], : img.shape[2]] = img
-        img = paded
-
-        # Move to gpu
-        if device == torch.device("cuda"):
-            img = img.cuda(non_blocking=True)
-        
-        # Size for transformers
-        w_featmap = img.shape[-2] // args.patch_size
-        h_featmap = img.shape[-1] // args.patch_size
-
-        # ------------ GROUND-TRUTH -------------------------------------------
-        if not args.no_evaluation:
-            gt_bbxs, gt_cls = dataset.extract_gt(inp[1], im_name)
-
-            if gt_bbxs is not None:
-                # Discard images with no gt annotations
-                # Happens only in the case of VOC07 and VOC12
-                if gt_bbxs.shape[0] == 0 and args.no_hard:
-                    continue
-
-        # ------------ EXTRACT FEATURES -------------------------------------------
-        with torch.no_grad():
-
-            # ------------ FORWARD PASS -------------------------------------------
-            if "vit" in args.arch:
-                # Store the outputs of qkv layer from the last attention layer
-                feat_out = {}
-                def hook_fn_forward_qkv(module, input, output):
-                    feat_out["qkv"] = output
-                model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv)
-
-                # Forward pass in the model
-                attentions = model.get_last_selfattention(img[None, :, :, :])
-
-                # Scaling factor
-                scales = [args.patch_size, args.patch_size]
-
-                # Dimensions
-                nb_im = attentions.shape[0]  # Batch size
-                nh = attentions.shape[1]  # Number of heads
-                nb_tokens = attentions.shape[2]  # Number of tokens
-
-                # Baseline: compute DINO segmentation technique proposed in the DINO paper
-                # and select the biggest component
-                if args.dinoseg:
-                    pred = dino_seg(attentions, (w_featmap, h_featmap), args.patch_size, head=args.dinoseg_head)
-                    pred = np.asarray(pred)
-                else:
-                    # Extract the qkv features of the last attention layer
-                    qkv = (
-                        feat_out["qkv"]
-                        .reshape(nb_im, nb_tokens, 3, nh, -1 // nh)
-                        .permute(2, 0, 3, 1, 4)
-                    )
-                    q, k, v = qkv[0], qkv[1], qkv[2]
-                    k = k.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
-                    q = q.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
-                    v = v.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
-
-                    # Modality selection
-                    if args.which_features == "k":
-                        feats = k[:, 1:, :]
-                    elif args.which_features == "q":
-                        feats = q[:, 1:, :]
-                    elif args.which_features == "v":
-                        feats = v[:, 1:, :]
-
-            elif "resnet" in args.arch:
-                x = model.forward(img[None, :, :, :])
-                d, w_featmap, h_featmap = x.shape[1:]
-                feats = x.reshape((1, d, -1)).transpose(2, 1)
-                # Apply layernorm
-                layernorm = nn.LayerNorm(feats.size()[1:]).to(device)
-                feats = layernorm(feats)
-                # Scaling factor
-                scales = [
-                    float(img.shape[1]) / x.shape[2],
-                    float(img.shape[2]) / x.shape[3],
-                ]
-            elif "vgg16" in args.arch:
-                x = model.forward(img[None, :, :, :])
-                d, w_featmap, h_featmap = x.shape[1:]
-                feats = x.reshape((1, d, -1)).transpose(2, 1)
-                # Apply layernorm
-                layernorm = nn.LayerNorm(feats.size()[1:]).to(device)
-                feats = layernorm(feats)
-                # Scaling factor
-                scales = [
-                    float(img.shape[1]) / x.shape[2],
-                    float(img.shape[2]) / x.shape[3],
-                ]
-            else:
-                raise ValueError("Unknown model.")
-
-        # ------------ Apply LOST -------------------------------------------
-        if not args.dinoseg:
-            preds, A, scores, seeds = lost(
-            feats,
-            [w_featmap, h_featmap],
-            scales,
-            init_image_size,
-            k_patches=args.k_patches,
-            num_init_seeds=args.num_init_seeds
-            )
-
-            # ------------ Visualizations -------------------------------------------
-            if args.visualize == "fms":
-                for i, x in enumerate(zip(preds, seeds)):
-                    pred, seed = x
-                    visualize_fms(A.clone().cpu().numpy(), seed, scores, [w_featmap, h_featmap], scales, vis_folder, im_name+'_'+str(i))
-
-            elif args.visualize == "seed_expansion":
-                for i, x in enumerate(zip(preds, seeds)):
-                    pred, seed = x
-                    image = dataset.load_image(im_name)
-
-                    # Before expansion
-                    pred_seed, _ = detect_box(
-                        A[seed, :],
-                        seed,
-                        [w_featmap, h_featmap],
-                        scales=scales,
-                        initial_im_size=init_image_size[1:],
-                    )
-                    visualize_seed_expansion(image, pred, seed, pred_seed, scales, [w_featmap, h_featmap], vis_folder, im_name+'_'+str(i))
-
-            elif args.visualize == "pred":
-                image = dataset.load_image(im_name)
-                for i, x in enumerate(zip(preds, seeds)):
-                    pred, seed = x
-                    image_name = None
-                    if i == len(preds) -1:
-                        image_name = im_name
-                    visualize_predictions(image, pred, seed, scales, [w_featmap, h_featmap], vis_folder, image_name)
-
-            # Save the prediction
-            #preds_dict[im_name] = preds
-            
-            # Evaluation
-            if args.no_evaluation:
-                continue
-
-            # Compare prediction to GT boxes
-        for pred in preds:
-            if len(preds) == 0:
-                continue
-
-            if len(gt_bbxs) == 0:
-                break # TODO: should do something else, should skip iou but count towards FP if pred exists
-
-            ious = bbox_iou(torch.from_numpy(pred), torch.from_numpy(np.asarray(gt_bbxs)))
-
-            # TODO: This calculates the corloc
-            # we need to calculate the AP50
-            if torch.any(ious >= 0.50):
-                #corloc[im_id] = 1
-                corloc[im_id] = 0
-            for i in ious:
-                if i >= 0.50:
-                    corloc[im_id] += 1 
-
-        cnt += len(gt_bbxs)
-        
-        if cnt % 50 == 0:
-            pbar.set_description(f"Found {int(np.sum(corloc))}/{cnt}")
-
-
-    # Save predicted bounding boxes
-    if args.save_predictions:
-        folder = f"{args.output_dir}/{exp_name}"
-        os.makedirs(folder, exist_ok=True)
-        filename = os.path.join(folder, "preds.pkl")
-        with open(filename, "wb") as f:
-            pickle.dump(preds_dict, f)
-        print("Predictions saved at %s" % filename)
-
-    # Evaluate
-    if not args.no_evaluation:
-        print(f"corloc: {100*np.sum(corloc)/cnt:.2f} ({int(np.sum(corloc))}/{cnt})")
-        result_file = os.path.join(folder, 'results.txt')
-        with open(result_file, 'w') as f:
-            f.write('corloc,%.1f,,\n'%(100*np.sum(corloc)/cnt))
-        print('File saved at %s'%result_file)
+# Copyright 2021 - Valeo Comfort and Driving Assistance - Oriane Siméoni @ valeo.ai
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import random
+import pickle
+
+import torch
+import torch.nn as nn
+import numpy as np
+
+from tqdm import tqdm
+from PIL import Image
+
+from networks import get_model
+from datasets import ImageDataset, Dataset, bbox_iou
+from visualizations import visualize_fms, visualize_predictions, visualize_seed_expansion
+from object_discovery import lost, detect_box, dino_seg
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Unsupervised object discovery with LOST.")
+    parser.add_argument(
+        "--arch",
+        default="vit_small",
+        type=str,
+        choices=[
+            "vit_tiny",
+            "vit_small",
+            "vit_base",
+            "resnet50",
+            "vgg16_imagenet",
+            "resnet50_imagenet",
+        ],
+        help="Model architecture.",
+    )
+    parser.add_argument(
+        "--patch_size", default=16, type=int, help="Patch resolution of the model."
+    )
+
+    # Use a dataset
+    parser.add_argument(
+        "--dataset",
+        default="VOC07",
+        type=str,
+        choices=[None, "VOC07", "VOC12", "COCO20k"],
+        help="Dataset name.",
+    )
+    parser.add_argument(
+        "--set",
+        default="train",
+        type=str,
+        choices=["val", "train", "trainval", "test"],
+        help="Path of the image to load.",
+    )
+    # Or use a single image
+    parser.add_argument(
+        "--image_path",
+        type=str,
+        default=None,
+        help="If want to apply only on one image, give file path.",
+    )
+
+    # Folder used to output predictions and visualizations
+    parser.add_argument(
+        "--output_dir", type=str, default="outputs", help="Output directory to store predictions and visualizations."
+    )
+
+    # Evaluation setup
+    parser.add_argument("--no_hard", action="store_true", help="Only used in the case of the VOC_all setup (see the paper).")
+    parser.add_argument("--no_evaluation", action="store_true", help="Compute the evaluation.")
+    parser.add_argument("--save_predictions", default=True, type=bool, help="Save predicted bouding boxes.")
+    parser.add_argument("--num_init_seeds", default=1, type=int, help="Number of initial seeds to expand from.")
+
+    # Visualization
+    parser.add_argument(
+        "--visualize",
+        type=str,
+        choices=["fms", "seed_expansion", "pred", None],
+        default=None,
+        help="Select the different type of visualizations.",
+    )
+
+    # For ResNet dilation
+    parser.add_argument("--resnet_dilate", type=int, default=2, help="Dilation level of the resnet model.")
+
+    # LOST parameters
+    parser.add_argument(
+        "--which_features",
+        type=str,
+        default="k",
+        choices=["k", "q", "v"],
+        help="Which features to use",
+    )
+    parser.add_argument(
+        "--k_patches",
+        type=int,
+        default=100,
+        help="Number of patches with the lowest degree considered."
+    )
+
+    # Use dino-seg proposed method
+    parser.add_argument("--dinoseg", action="store_true", help="Apply DINO-seg baseline.")
+    parser.add_argument("--dinoseg_head", type=int, default=4)
+
+    args = parser.parse_args()
+
+    if args.image_path is not None:
+        args.save_predictions = False
+        args.no_evaluation = True
+        args.dataset = None
+
+    # -------------------------------------------------------------------------------------------------------
+    # Dataset
+
+    # If an image_path is given, apply the method only to the image
+    if args.image_path is not None:
+        dataset = ImageDataset(args.image_path)
+    else:
+        dataset = Dataset(args.dataset, args.set, args.no_hard)
+
+    # -------------------------------------------------------------------------------------------------------
+    # Model
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    print("Running on device:", device)
+    model = get_model(args.arch, args.patch_size, args.resnet_dilate, device)
+
+    # -------------------------------------------------------------------------------------------------------
+    # Directories
+    if args.image_path is None:
+        args.output_dir = os.path.join(args.output_dir, dataset.name)
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Naming
+    if args.dinoseg:
+        # Experiment with the baseline DINO-seg
+        if "vit" not in args.arch:
+            raise ValueError("DINO-seg can only be applied to tranformer networks.")
+        exp_name = f"{args.arch}-{args.patch_size}_dinoseg-head{args.dinoseg_head}"
+    else:
+        # Experiment with LOST
+        exp_name = f"LOST-{args.arch}"
+        if "resnet" in args.arch:
+            exp_name += f"dilate{args.resnet_dilate}"
+        elif "vit" in args.arch:
+            exp_name += f"{args.patch_size}_{args.which_features}"
+
+    print(f"Running LOST on the dataset {dataset.name} (exp: {exp_name})")
+
+    # Visualization 
+    if args.visualize:
+        vis_folder = f"{args.output_dir}/visualizations/{exp_name}"
+        os.makedirs(vis_folder, exist_ok=True)
+    
+    # -------------------------------------------------------------------------------------------------------
+    # Loop over images
+    preds_dict = {}
+    gt_dict = {}
+    cnt = 0
+    corloc = np.zeros(len(dataset.dataloader))
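+    # corloc[im_id] counts IoU >= 0.5 matches in image im_id; cnt accumulates the total number of GT boxes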
+    
+    pbar = tqdm(dataset.dataloader)
+    for im_id, inp in enumerate(pbar):
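+        # Release cached GPU memory left over from the previous image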
+        torch.cuda.empty_cache()
+        # ------------ IMAGE PROCESSING -------------------------------------------
+        img = inp[0]
+        init_image_size = img.shape
+
+        # Get the name of the image
+        im_name = dataset.get_image_name(inp[1])
+
+        # Pass in case of no gt boxes in the image
+        if im_name is None:
+            continue
+
+        # Pad the image with zeros so its spatial dimensions are multiples of the patch size
+        size_im = (
+            img.shape[0],
+            int(np.ceil(img.shape[1] / args.patch_size) * args.patch_size),
+            int(np.ceil(img.shape[2] / args.patch_size) * args.patch_size),
+        )
+        padded = torch.zeros(size_im)
+        padded[:, : img.shape[1], : img.shape[2]] = img
+        img = padded
+
+        # Move to gpu
+        if device == torch.device("cuda"):
+            img = img.cuda(non_blocking=True)
+        
+        # Feature-map size for transformers (number of patches along each spatial dimension)
+        w_featmap = img.shape[-2] // args.patch_size
+        h_featmap = img.shape[-1] // args.patch_size
+
+        # ------------ GROUND-TRUTH -------------------------------------------
+        if not args.no_evaluation:
+            gt_bbxs, gt_cls = dataset.extract_gt(inp[1], im_name)
+
+            if gt_bbxs is not None:
+                # Discard images with no gt annotations
+                # Happens only in the case of VOC07 and VOC12
+                if gt_bbxs.shape[0] == 0 and args.no_hard:
+                    continue
+
+        # ------------ EXTRACT FEATURES -------------------------------------------
+        with torch.no_grad():
+
+            # ------------ FORWARD PASS -------------------------------------------
+            if "vit" in args.arch:
+                # Store the outputs of qkv layer from the last attention layer
+                feat_out = {}
+                def hook_fn_forward_qkv(module, input, output):
+                    feat_out["qkv"] = output
+                model._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook(hook_fn_forward_qkv)
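+                # The hook fires during the forward pass below and stores the concatenated q, k, v projections in feat_out["qkv"]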
+
+                # Forward pass in the model
+                attentions = model.get_last_selfattention(img[None, :, :, :])
+
+                # Scaling factor
+                scales = [args.patch_size, args.patch_size]
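+                # (each feature-map cell corresponds to a patch_size x patch_size block of pixels in the padded image)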
+
+                # Dimensions
+                nb_im = attentions.shape[0]  # Batch size
+                nh = attentions.shape[1]  # Number of heads
+                nb_tokens = attentions.shape[2]  # Number of tokens
+
+                # Baseline: compute DINO segmentation technique proposed in the DINO paper
+                # and select the biggest component
+                if args.dinoseg:
+                    pred = dino_seg(attentions, (w_featmap, h_featmap), args.patch_size, head=args.dinoseg_head)
+                    pred = np.asarray(pred)
+                else:
+                    # Extract the qkv features of the last attention layer
+                    qkv = (
+                        feat_out["qkv"]
+                        .reshape(nb_im, nb_tokens, 3, nh, -1 // nh)
+                        .permute(2, 0, 3, 1, 4)
+                    )
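+                    # qkv now has shape (3, batch, heads, tokens, head_dim); -1 // nh evaluates to -1, so reshape infers the last dimension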
+                    q, k, v = qkv[0], qkv[1], qkv[2]
+                    k = k.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
+                    q = q.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
+                    v = v.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
+
+                    # Modality selection
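+                    # (the first token is the CLS token, so feats keeps only the patch tokens)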
+                    if args.which_features == "k":
+                        feats = k[:, 1:, :]
+                    elif args.which_features == "q":
+                        feats = q[:, 1:, :]
+                    elif args.which_features == "v":
+                        feats = v[:, 1:, :]
+
+            elif "resnet" in args.arch:
+                x = model.forward(img[None, :, :, :])
+                d, w_featmap, h_featmap = x.shape[1:]
+                feats = x.reshape((1, d, -1)).transpose(2, 1)
+                # Apply layernorm
+                layernorm = nn.LayerNorm(feats.size()[1:]).to(device)
+                feats = layernorm(feats)
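+                # (normalizes the CNN features jointly over patch positions and channels)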
+                # Scaling factor
+                scales = [
+                    float(img.shape[1]) / x.shape[2],
+                    float(img.shape[2]) / x.shape[3],
+                ]
+            elif "vgg16" in args.arch:
+                x = model.forward(img[None, :, :, :])
+                d, w_featmap, h_featmap = x.shape[1:]
+                feats = x.reshape((1, d, -1)).transpose(2, 1)
+                # Apply layernorm
+                layernorm = nn.LayerNorm(feats.size()[1:]).to(device)
+                feats = layernorm(feats)
+                # Scaling factor
+                scales = [
+                    float(img.shape[1]) / x.shape[2],
+                    float(img.shape[2]) / x.shape[3],
+                ]
+            else:
+                raise ValueError("Unknown model.")
+
+        # ------------ Apply LOST -------------------------------------------
+        if not args.dinoseg:
+            preds, A, scores, seeds = lost(
+                feats,
+                [w_featmap, h_featmap],
+                scales,
+                init_image_size,
+                k_patches=args.k_patches,
+                num_init_seeds=args.num_init_seeds,
+            )
+
+            # ------------ Visualizations -------------------------------------------
+            if args.visualize == "fms":
+                for i, x in enumerate(zip(preds, seeds)):
+                    pred, seed = x
+                    visualize_fms(A.clone().cpu().numpy(), seed, scores, [w_featmap, h_featmap], scales, vis_folder, im_name+'_'+str(i))
+
+            elif args.visualize == "seed_expansion":
+                for i, x in enumerate(zip(preds, seeds)):
+                    pred, seed = x
+                    image = dataset.load_image(im_name)
+
+                    # Before expansion
+                    pred_seed, _ = detect_box(
+                        A[seed, :],
+                        seed,
+                        [w_featmap, h_featmap],
+                        scales=scales,
+                        initial_im_size=init_image_size[1:],
+                    )
+                    visualize_seed_expansion(image, pred, seed, pred_seed, scales, [w_featmap, h_featmap], vis_folder, im_name+'_'+str(i))
+
+            elif args.visualize == "pred":
+                image = dataset.load_image(im_name)
+                for i, x in enumerate(zip(preds, seeds)):
+                    pred, seed = x
+                    image_name = None
+                    if i == len(preds) -1:
+                        image_name = im_name
+                    visualize_predictions(image, pred, seed, scales, [w_featmap, h_featmap], vis_folder, image_name)
+
+            # Save the prediction
+            #preds_dict[im_name] = preds
+            
+            # Evaluation
+            if args.no_evaluation:
+                continue
+
+        # Initialize per-image counters for the IoU-0.5 precision/recall estimate
+        tp = 0
+        fp = 0
+        total_gt_boxes = len(gt_bbxs)
+        ap50 = 0
+
+        # Compare predictions to GT boxes
+        for pred in preds:
+            if len(gt_bbxs) == 0:
+                break  # TODO: should skip the IoU computation but still count the prediction as a FP
+
+            ious = bbox_iou(torch.from_numpy(pred), torch.from_numpy(np.asarray(gt_bbxs)))
+
+            # Count GT boxes matched by this prediction at IoU >= 0.5
+            # (note: a GT box matched by several predictions is counted once per prediction)
+            for i in ious:
+                if i >= 0.50:
+                    corloc[im_id] += 1
+
+            # Count true positives and false positives at IoU threshold of 0.5
+            if torch.any(ious >= 0.50):
+                tp += 1
+            else:
+                fp += 1
+
+        cnt += len(gt_bbxs)
+
+        if cnt % 50 == 0:
+            pbar.set_description(f"Found {int(np.sum(corloc))}/{cnt}")
+
+        # Precision and recall for this image at IoU threshold 0.5
+        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+        recall = tp / total_gt_boxes if total_gt_boxes > 0 else 0.0
+
+        # Per-image precision x recall at IoU 0.5, reported here as "AP50"
+        # (a proper AP50 would rank predictions by confidence over the whole dataset)
+        ap50 = precision * recall
+        print(f"AP50: {ap50:.2f}")
+
+    # Save predicted bounding boxes
+    if args.save_predictions:
+        folder = f"{args.output_dir}/{exp_name}"
+        os.makedirs(folder, exist_ok=True)
+        filename = os.path.join(folder, "preds.pkl")
+        with open(filename, "wb") as f:
+            pickle.dump(preds_dict, f)
+        print("Predictions saved at %s" % filename)
+
+    # Evaluate
+    if not args.no_evaluation:
+        print(f"corloc: {100*np.sum(corloc)/cnt:.2f} ({int(np.sum(corloc))}/{cnt})")
+        result_file = os.path.join(folder, 'results.txt')
+        with open(result_file, 'w') as f:
+            f.write('corloc,%.1f,,\n'%(100*np.sum(corloc)/cnt))
+        print('File saved at %s'%result_file)