# train_model.py
"""
Responsible for training the machine learning model for recognizing faces.
"""


from sklearn.preprocessing import LabelEncoder
from sklearn.svm import SVC

import common
import data_handler

def train_and_save(facial_embeddings_database: str = common.DATABASE_LOC,
                   output_file: str = common.RECOGNITION_DATABASE_LOC) -> None:
    """
    Loads the facial embeddings database, trains a recognizer on it, and writes the trained model to file.
    :param facial_embeddings_database: Location of the facial embeddings database to load.
    :param output_file: Location the trained recognizer is written to.
    :return: None
    """
    embeddings = data_handler.load_database(facial_embeddings_database)
    trained_recognizer = train_model(embeddings)
    data_handler.write_database(output_file, trained_recognizer)


def train_model(facial_embeddings: dict) -> SVC:
    """
    Trains a linear SVM classifier on the given facial embeddings.
    :param facial_embeddings: Mapping of user id to that user's collection of facial encodings;
        every encoding under a user id becomes one training sample labelled with that id.
    :return: The fitted SVC. Its classes are the integer codes produced by a LabelEncoder fit on
        the user ids, not the user ids themselves. The encoder is not returned.
    """
    label_encoder = LabelEncoder()
    # One label per encoding: each user id is repeated once for every encoding stored under it.
    user_id_repeat_list = []
    for user_id, encodings in facial_embeddings.items():
        user_id_repeat_list.extend([user_id] * len(encodings))

    labels = label_encoder.fit_transform(user_id_repeat_list)

    recognizer = SVC(C=1.0, kernel="linear", probability=True)
    # NOTE(review): labels are built in facial_embeddings iteration order. This assumes
    # data_handler.get_encodings_in_database returns the encodings flattened in that same order.
    # TODO confirm against data_handler.
    recognizer.fit(data_handler.get_encodings_in_database(facial_embeddings), labels)

    return recognizer

if __name__ == "__main__":
    # Run as a script: retrain using the default database locations from `common`.
    train_and_save()