Skip to content
Snippets Groups Projects
Commit e34f6e70 authored by Mingchung Xia's avatar Mingchung Xia
Browse files

Moved dictionary operations to separate file

parent a0dd924c
No related branches found
No related tags found
1 merge request!13HNSW Implementation with Testcases
//
// File.swift
//
// HNSWCorpus + Dictionary.swift
//
//
// Created by Mingchung Xia on 2024-02-14.
//
......@@ -42,7 +42,10 @@ extension HNSWCorpus {
return dictionary
}
private func addDocumentVectorPair(at key: Int, document: String, vector: [Scalar]) {
dictionary[key] = DocumentVectorPair(untokenizedDocument: document, vector: vector)
func addDocumentVectorPair(at key: Int, document: String, vector: [Scalar]) {
dictionary[key] = DocumentVectorPair(
untokenizedDocument: document,
vector: vector
)
}
}
......@@ -25,7 +25,6 @@ import Foundation
class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
// internal var _documentEncoder: ContextFreeEncoder<Scalar>
internal var _documentEncoder: any SNLPEncoder
var zeroes: [Scalar] { _documentEncoder.zeroes as! [Scalar] }
......@@ -33,7 +32,7 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
var count: Int { encodedDocuments.base.vectors.count }
// Keeps track of the original document for client code
private var dictionary: [Int: DocumentVectorPair] = [:]
var dictionary: [Int: DocumentVectorPair] = [:]
// typicalNeighbourhoodSize = 20 is a standard benchmark
init(encoding: ContextFreeEncoder<Scalar>.PreComputedEmbeddings, scalar: Scalar.Type = Double.self,
......@@ -46,12 +45,6 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
_documentEncoder = encoder
encodedDocuments = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize)
}
// init(_documentEncoder: ContextFreeEncoder<Scalar>, typicalNeighborhoodSize: Int = 20) {
// self._documentEncoder = _documentEncoder
// zeroes = Array(repeating: Scalar(0), count: 384)
// encodedDocuments = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize)
// }
// Decodable conformance
// required init(from decoder: Decoder) throws {
......@@ -72,44 +65,4 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
vector: encodedDocuments.base.vectors[key]
)
}
// @inlinable
// func addVector(_ vector: [Scalar]) {
// encodedDocuments.insert(vector)
// }
}
extension HNSWCorpus {
/// This extension is used for the dictionary operations
private struct DocumentVectorPair {
var untokenizedDocument: String
var vector: [Scalar]
init(untokenizedDocument: String, vector: [Scalar]) {
self.untokenizedDocument = untokenizedDocument
self.vector = vector
}
}
@inlinable
func getUntokenizedDocument(at key: Int) -> String {
if let pair = dictionary[key] {
return pair.untokenizedDocument
} else {
fatalError("Key \(key) not found in HNSW dictionary")
}
}
@inlinable
func getVector(at key: Int) -> [Scalar] {
if let pair = dictionary[key] {
return pair.vector
} else {
fatalError("Key \(key) not found in HNSW dictionary")
}
}
private func addDocumentVectorPair(at key: Int, document: String, vector: [Scalar]) {
dictionary[key] = DocumentVectorPair(untokenizedDocument: document, vector: vector)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment