diff --git a/README.md b/README.md
index 8c2a24cfde3b97d8de2e442f6c032c1fa02b2df6..fb13accb459a03d14aee7a4c64da60320a49d932 100644
--- a/README.md
+++ b/README.md
@@ -197,6 +197,26 @@ cd $LOST;
 python main_corloc_evaluation.py --dataset VOC07 --set trainval --type_pred detectron --pred_file $D2/outputs/RN50_DINO_FRCNN_VOC07_CAD/inference/coco_instances_results.json
 ```
 
+#### Training LOST+CAD on COCO20k dataset
+Following are the command lines allowing to train a detector in a class-agnostic fashion on the COCO20k subset of COCO dataset.
+
+```bash
+cd $D2;
+
+# Format pseudo-boxes data to fit detectron2
+python tools/prepare_coco_LOST_CAD_pseudo_boxes_in_detectron2_format.py --pboxes $LOST/outputs/COCO20k_train/LOST-vit_small16_k/preds.pkl
+
+# Generate COCO20k CAD gt annotations
+python tools/prepare_coco_CAD_gt.py --coco_dir $LOST/datasets/COCO
+
+# Train detector (evaluation done on COCO20k CAD training set)
+python tools/train_net_for_LOST_CAD.py --num-gpus 4 --config-file ./configs/LOST/RN50_DINO_FRCNN_COCO20k_CAD.yaml DATALOADER.NUM_WORKERS 8 OUTPUT_DIR ./outputs/RN50_DINO_FRCNN_COCO20k_CAD MODEL.WEIGHTS ./data/dino_RN50_pretrain_d2_format.pkl
+
+# Corloc evaluation
+python main_corloc_evaluation.py --dataset COCO20k --type_pred detectron --pred_file $D2/outputs/RN50_DINO_FRCNN_COCO20k_CAD/inference/coco_instances_results.json
+```
+
+
 #### Evaluating LOST+CAD (corloc results)
 
 We have provided predictions of a class-agnostic Faster R-CNN model trained using LOST boxes as pseudo-gt; they are stored in the folder `data/CAD_predictions`. In order to launch the corloc evaluation, please launch the following scripts. It is to be noted that in this evaluation, only the box with the highest confidence score is considered per image. 
diff --git a/tools/configs/RN50_DINO_FRCNN_COCO20k_CAD.yaml b/tools/configs/RN50_DINO_FRCNN_COCO20k_CAD.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cdcf9b3b8d791eb7667f8e0c7931be9402b9e0e4
--- /dev/null
+++ b/tools/configs/RN50_DINO_FRCNN_COCO20k_CAD.yaml
@@ -0,0 +1,40 @@
+MODEL:
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  RPN:
+    PRE_NMS_TOPK_TEST: 6000
+    POST_NMS_TOPK_TEST: 1000
+  WEIGHTS: "data/dino_RN50_pretrain_d2_format.pkl"
+  MASK_ON: False
+  RESNETS:
+    DEPTH: 50
+    STRIDE_IN_1X1: False
+    NORM: "SyncBN"
+  ROI_HEADS:
+    NAME: "Res5ROIHeadsExtraNorm"
+    NUM_CLASSES: 1
+    SCORE_THRESH_TEST: 0.01
+    NMS_THRESH_TEST: 0.4
+  BACKBONE:
+    FREEZE_AT: 2
+  ROI_BOX_HEAD:
+    NORM: "SyncBN" # RGB Mean and Std
+  PIXEL_MEAN: [123.675, 116.280, 103.530]
+  PIXEL_STD: [58.395, 57.120, 57.375]
+INPUT:
+  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
+  MIN_SIZE_TEST: 800
+  FORMAT: "RGB"
+DATASETS:
+  TRAIN: ('coco20k_train_LOST_CAD', )
+  TEST: ('coco20k_train_CAD_gt', )
+TEST:
+  EVAL_PERIOD: 5000
+  PRECISE_BN:
+    ENABLED: True
+SOLVER:
+  STEPS: (18000, 22000)
+  MAX_ITER: 24000
+  WARMUP_ITERS: 100 # Maybe needs tuning.
+  IMS_PER_BATCH: 16
+  BASE_LR: 0.02 # Maybe it will need some tuning. MoCo used 0.02.
+OUTPUT_DIR: "./outputs/RN50_DINO_FRCNN_COCO20k_CAD"
\ No newline at end of file
diff --git a/tools/prepare_coco_CAD_gt.py b/tools/prepare_coco_CAD_gt.py
new file mode 100644
index 0000000000000000000000000000000000000000..8861b327453f6e0d128a2a6817b25521c2c7e1d5
--- /dev/null
+++ b/tools/prepare_coco_CAD_gt.py
@@ -0,0 +1,75 @@
+# Copyright 2021 - Valeo Comfort and Driving Assistance
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import pathlib
+import argparse
+import detectron2.data
+from tqdm import tqdm
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="Prepares the CAD gt for COCO20k"
+                    "dataset in the data format expected from detectron2.")
+    parser.add_argument("--coco_dir", type=str, default='../datasets/COCO',
+                        help="Path to where the COCO dataset is.")
+    parser.add_argument("--file_coco20k", type=str, default='../datasets/coco_20k_filenames.txt',
+                        help="Location of COCO20k subset.")
+    args = parser.parse_args()
+
+    print('Prepare Class-Agnostic COCO20k in the data format expected from detectron2.')
+
+    # Load COCO20k images
+    coco_20k_f = '../datasets/coco_20k_filenames.txt'
+    with open(args.file_coco20k, "r") as f:
+        sel_20k = f.readlines()
+        sel_20k = [s.replace("\n", "") for s in sel_20k]
+    im20k = [str(int(s.split("_")[-1].split(".")[0])) for s in sel_20k]
+
+    # Load annotations
+    annotation_file = pathlib.Path(args.coco_dir) / "annotations" / "instances_train2014.json"
+    with open(annotation_file) as json_file:
+        annot = json.load(json_file)
+
+    coco_data_gt_train14 = detectron2.data.DatasetCatalog.get("coco_2014_train")
+    ann_to_img_ids = [x['id'] for ind, x in enumerate(annot['images'])]
+    map_id_to_annot = [x['image_id'] for x in coco_data_gt_train14]
+
+    data_gt_20k = []
+    for file_name in tqdm(sel_20k):
+
+        image_name = file_name[:-len('.jpg')]
+        image_id = image_name.split('_')[-1].split('.')[0]
+        image_id_int = int(image_id)
+        
+        full_img_path = pathlib.Path(args.coco_dir) / "images" / file_name
+        ann_id = ann_to_img_ids.index(image_id_int)
+        assert full_img_path.is_file()
+        annotations = coco_data_gt_train14[map_id_to_annot.index(image_id_int)]["annotations"]
+        ca_annotations = [{'iscrowd':v['iscrowd'], 'bbox':v['bbox'], 'category_id': 0, 'bbox_mode':v['bbox_mode']} for v in annotations]
+
+        data_gt_20k.append({
+            "file_name": str(full_img_path),
+            "image_id": image_id,
+            "height": annot['images'][ann_id]['height'],
+            "width": annot['images'][ann_id]['width'],
+            "annotations": ca_annotations,
+        })
+
+    print("Dataset COCO20k CAD-gt has been saved.")
+
+    json_data = {"dataset": data_gt_20k,}
+    with open(f'./datasets/coco20k_trainval_CAD_gt.json', 'w') as outfile:
+        json.dump(json_data, outfile)
diff --git a/tools/prepare_coco_LOST_CAD_pseudo_boxes_in_detectron2_format.py b/tools/prepare_coco_LOST_CAD_pseudo_boxes_in_detectron2_format.py
new file mode 100644
index 0000000000000000000000000000000000000000..012e7a9568e5e1efed94a78cb245318d3baeb34c
--- /dev/null
+++ b/tools/prepare_coco_LOST_CAD_pseudo_boxes_in_detectron2_format.py
@@ -0,0 +1,110 @@
+# Copyright 2021 - Valeo Comfort and Driving Assistance
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import pickle
+import pathlib
+import argparse
+from tqdm import tqdm
+
+import xml.etree.ElementTree as ET
+from detectron2.structures import BoxMode
+
+def get_img_size(ann_file):
+    # Get the width and height from the annotation file.
+    ann_file = open(ann_file)
+    tree = ET.parse(ann_file)
+    root = tree.getroot()
+    size = root.find('size')
+    width = int(size.find('width').text)
+    height = int(size.find('height').text)
+    return width, height
+
+
+def prepare_annotation_data(loc_object):
+    if not isinstance(loc_object[0], (list, tuple)):
+        loc_object = [loc_object,]
+
+    annotations = []
+    for obj in loc_object:
+        xmin, ymin, xmax, ymax = [float(x) for x in obj]
+        annotations.append({
+            "iscrowd": 0,
+            "bbox": [xmin, ymin, xmax, ymax],
+            "category_id": 0,
+            "bbox_mode": BoxMode.XYXY_ABS})
+
+    return annotations
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="Prepares the LOST pseudo-boxes from a COCO2014"
+                    "dataset in the data format expected from detectron2.")
+    parser.add_argument("--coco_dir", type=str, default='../datasets/COCO',
+                        help="Path to where the VOC dataset is.")
+    parser.add_argument("--pboxes", type=str, default='../outputs/COCO20k_train/LOST-vit_small16_k/preds.pkl',
+                        help="Path to where the LOST CA pseudo boxes for the VOCyear trainval data are.")
+    args = parser.parse_args()
+
+    print('Prepare LOST pseudo-boxes (COCO2014) in the data format expected from detectron2.')
+
+    # Load the boxes
+    with open(args.pboxes, 'rb') as handle:
+        LOST_pseudo_boxes = pickle.load(handle)
+
+    annotation_file = pathlib.Path(args.coco_dir) / "annotations" / "instances_train2014.json"
+    with open(annotation_file) as json_file:
+        annot = json.load(json_file)
+
+    data = []
+    cnt = 0
+    for image_name in tqdm(LOST_pseudo_boxes):
+        if 'jpg' in image_name:
+            image_name = image_name[:-len('.jpg')]
+        else:
+            image_name_init = image_name
+            ann_id = [ind for ind, x in enumerate(annot['images']) if x['id'] == int(image_name)][0]
+            image_name = 'train2014/' + annot['images'][ann_id]['file_name']
+
+        image_id = image_name.split('_')[-1].split('.')[0]
+        image_id_int = int(image_id)
+        full_img_path = pathlib.Path(args.coco_dir) / "images" / image_name
+        ann_id = [ind for ind, x in enumerate(annot['images']) if x['id'] == image_id_int][0]
+        assert full_img_path.is_file()
+
+        data.append({
+            "file_name": str(full_img_path),
+            "image_id": image_id,
+            "height": annot['images'][ann_id]['height'], "width": annot['images'][ann_id]['width'],
+            "annotations": prepare_annotation_data(LOST_pseudo_boxes[image_name_init]),
+        })
+        cnt += 1
+
+    print(f'Number images saved {cnt}')
+    dataset_name = f"coco20k_train_LOST_CAD"
+    json_data = {
+        "dataset": data,
+        "meta_data": {
+            "dirname": args.coco_dir,
+            "evaluator_type": "pascal_voc",
+            "name": dataset_name,
+            "split": "train",
+            "year": 2014,
+            "thing_classes": "object",
+        }}
+    dst_file = f'./datasets/{dataset_name}.json'
+    print(f"The pseudo-boxes at {args.pboxes} will be transformed into a detectron2-compatible dataset format at {dst_file}")
+    with open(dst_file, 'w') as outfile:
+        json.dump(json_data, outfile)
diff --git a/tools/train_net_for_LOST_CAD.py b/tools/train_net_for_LOST_CAD.py
index bbac8cf0f8e1eae931439d2f654473e09d48dad1..6aee2102542e716dd8e23e346308b2365a93332f 100755
--- a/tools/train_net_for_LOST_CAD.py
+++ b/tools/train_net_for_LOST_CAD.py
@@ -5,6 +5,7 @@
 
 import logging
 import os
+import copy
 from collections import OrderedDict
 import torch
 
@@ -96,8 +97,91 @@ def register_CAD_LOST_pseudo_boxes_for_the_voc2007_trainval_dataset(
     detectron2.data.MetadataCatalog.get(voc2007_dataset_name).thing_classes = ["object",]
     detectron2.data.MetadataCatalog.get(voc2007_dataset_name).evaluator_type = "coco"
 
+def register_CAD_objects_coco_train_dataset(image_root=None):
+    print(f"Registering the 'coco_train_CAD' for class agnostic object detection.")
+    def coco_train_ca_dataset_function():
+        coco_data_gt = detectron2.data.DatasetCatalog.get("coco_2014_train")
+        coco_data_gt = copy.deepcopy(coco_data_gt)
+        # Make the ground bounding boxes class agnostic (i.e., give to all of
+        # them the category id 0).
+        for i in range(len(coco_data_gt)):
+            if image_root is not None:
+                coco_data_gt[i]["file_name"] = \
+                    coco_data_gt[i]["file_name"].replace('datasets/coco', image_root)
+            for j in range(len(coco_data_gt[i]["annotations"])):
+                coco_data_gt[i]["annotations"][j]["category_id"] = 0
+        return coco_data_gt
+    detectron2.data.DatasetCatalog.register(
+        "coco_train_CAD", coco_train_ca_dataset_function)
+    detectron2.data.MetadataCatalog.get("coco_train_CAD").thing_classes = ["object",]
+    detectron2.data.MetadataCatalog.get("coco_train_CAD").evaluator_type = "coco"
+    detectron2.data.MetadataCatalog.get("coco_train_CAD").name = "coco_train_CAD"
+
+def register_CAD_objects_coco_val_dataset(image_root=None):
+    print(f"Registering the 'coco_val_CAD' for class agnostic object detection.")
+    def coco_val_ca_dataset_function():
+        coco_data_gt = detectron2.data.DatasetCatalog.get("coco_2014_val")
+        coco_data_gt = copy.deepcopy(coco_data_gt)
+        # Make the ground bounding boxes class agnostic (i.e., give to all of
+        # them the category id 0).
+        for i in range(len(coco_data_gt)):
+            if image_root is not None:
+                coco_data_gt[i]["file_name"] = \
+                    coco_data_gt[i]["file_name"].replace('datasets/coco', image_root)
+            for j in range(len(coco_data_gt[i]["annotations"])):
+                coco_data_gt[i]["annotations"][j]["category_id"] = 0
+        return coco_data_gt
+    detectron2.data.DatasetCatalog.register(
+        "coco_val_CAD", coco_val_ca_dataset_function)
+    detectron2.data.MetadataCatalog.get("coco_val_CAD").thing_classes = ["object",]
+    detectron2.data.MetadataCatalog.get("coco_val_CAD").evaluator_type = "coco"
+    detectron2.data.MetadataCatalog.get("coco_val_CAD").name = "coco_val_CAD"
+
+def register_CAD_coco20k_train_gt_dataset(
+    coco_json_path="./datasets/coco20k_trainval_CAD_gt.json",
+    coco_dataset_name="coco20k_train_CAD_gt"):
+
+    print(f"Registering the '{coco_dataset_name}' from the json file {coco_json_path}")
+    def coco_train_dataset_function():
+        with open(coco_json_path) as infile:
+            json_data = json.load(infile)
+        return json_data["dataset"]
+    detectron2.data.DatasetCatalog.register(
+        coco_dataset_name, coco_train_dataset_function)
+    detectron2.data.MetadataCatalog.get(coco_dataset_name).thing_classes = ["object",]
+    detectron2.data.MetadataCatalog.get(coco_dataset_name).evaluator_type = "coco"
+
+def register_CAD_LOST_pseudo_boxes_for_the_coco20k_trainval_dataset(
+    coco20k_json_path="./datasets/coco20k_train_LOST_CAD.json",
+    coco20k_dataset_name="coco20k_train_LOST_CAD"):
+
+    print(f"Registering the '{coco20k_dataset_name}' from the json file {coco20k_json_path}")
+    def coco20k_train_dataset_function():
+        with open(coco20k_json_path) as infile:
+            json_data = json.load(infile)
+        return json_data["dataset"]
+    detectron2.data.DatasetCatalog.register(
+        coco20k_dataset_name, coco20k_train_dataset_function)
+    detectron2.data.MetadataCatalog.get(coco20k_dataset_name).thing_classes = ["object",]
+    detectron2.data.MetadataCatalog.get(coco20k_dataset_name).evaluator_type = "coco"
+
+
+#*******************************************************************************
+#*******************************************************************************
+# Comment out those not needed.
+# Register VOC datasets
 register_voc_in_coco_style()
 register_CAD_LOST_pseudo_boxes_for_the_voc2007_trainval_dataset()
+
+# Register COCO dataset
+register_CAD_coco20k_train_gt_dataset()
+register_CAD_objects_coco_train_dataset(image_root='../datasets/COCO/images')
+register_CAD_objects_coco_val_dataset(image_root='../datasets/COCO/images')
+try:
+    register_CAD_LOST_pseudo_boxes_for_the_coco20k_trainval_dataset()
+except:
+    print("If failing here, please make sure to construct pseudo-boxes dataset using:\
+          >python tools/prepare_coco_LOST_CAD_pseudo_boxes_in_detectron2_format.py --pboxes /path/preds.pkl")
 #*******************************************************************************
 #*******************************************************************************