From c1f31cfde851af35e0596d67aec50d795a28e855 Mon Sep 17 00:00:00 2001
From: Mingchung Xia <mingchung.xia@gmail.com>
Date: Thu, 25 Jan 2024 17:36:15 -0500
Subject: [PATCH] HNSW Sequence and Collection conformance, document encoding
 and RNG insertions

---
 .../HNSWCorpus + Sequence.swift               | 40 +++++---------
 .../1. Data Collection/HNSWCorpus.swift       | 54 +++++++------------
 2 files changed, 31 insertions(+), 63 deletions(-)

diff --git a/Sources/SwiftNLP/1. Data Collection/HNSWCorpus + Sequence.swift b/Sources/SwiftNLP/1. Data Collection/HNSWCorpus + Sequence.swift
index 1bbbd959..11e901fb 100644
--- a/Sources/SwiftNLP/1. Data Collection/HNSWCorpus + Sequence.swift	
+++ b/Sources/SwiftNLP/1. Data Collection/HNSWCorpus + Sequence.swift	
@@ -21,56 +21,42 @@
 // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
 // OTHER DEALINGS IN THE SOFTWARE.
 
-// MARK: Sequence conformance will be done when HNSWCorpus is complete
-
  
-extension HNSWCorpus: Sequence {
+extension HNSWCorpus: Sequence, Collection {
         
     typealias Element = [Scalar]
     
     // Sequence Protocol Requirements
-    @inlinable 
+    @inlinable
     func makeIterator() -> AnyIterator<Element> {
-        /// DeterministicSampleVectorIndex likely does not have .count or .elementAt implemented
-        /// This provides a boilerplate template for the protocol conformance
-        
         var index = 0
         return AnyIterator {
-            guard index < self.encodedDocuments.count else { return nil } // .count
-            let element = self.encodedDocuments.elementAt(index) // .elementAt
-            index += 1
+            defer { index += 1 }
+            guard index < self.encodedDocuments.base.vectors.count else { return nil }
+            let element = self.encodedDocuments.base.vectors[index] // consider using .find
             return element
         }
     }
     
-    /*
-    // Sequence Protocol Requirements
-    @inlinable
-    func makeIterator() -> Dictionary<Int, [Scalar]>.Values.Iterator {
-        return encodedDocuments.values.makeIterator()
-    }
-    
-    
     // Collection Protocol Requirements
     @inlinable
-    var startIndex: Dictionary<Int, [Scalar]>.Index {
-        return encodedDocuments.startIndex
+    var startIndex: Int {
+        return encodedDocuments.base.vectors.startIndex
     }
     
     @inlinable
-    var endIndex: Dictionary<Int, [Scalar]>.Index {
-        return encodedDocuments.endIndex
+    var endIndex: Int {
+        return encodedDocuments.base.vectors.endIndex
     }
     
     @inlinable
-    subscript(position: Dictionary<Int, [Scalar]>.Index) -> [Scalar] {
-        encodedDocuments.values[position]
+    subscript(position: Int) -> Element {
+        return encodedDocuments.base.vectors[position]
     }
     
     @inlinable
-    func index(after i: Dictionary<Int, [Scalar]>.Index) -> Dictionary<Int, [Scalar]>.Index {
-        return encodedDocuments.index(after: i)
+    func index(after i: Int) -> Int {
+        return encodedDocuments.base.vectors.index(after: i)
     }
-    */
 }
 
diff --git a/Sources/SwiftNLP/1. Data Collection/HNSWCorpus.swift b/Sources/SwiftNLP/1. Data Collection/HNSWCorpus.swift
index c59a8aa6..a30c397e 100644
--- a/Sources/SwiftNLP/1. Data Collection/HNSWCorpus.swift	
+++ b/Sources/SwiftNLP/1. Data Collection/HNSWCorpus.swift	
@@ -24,7 +24,6 @@
 import Foundation
 import PriorityHeapModule
 import PriorityHeapAlgorithms
-
 import SimilarityMetric
 import HNSWAlgorithm
 import HNSWEphemeral
@@ -33,13 +32,11 @@ import GameplayKit
 
 class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
             
-    var _documentEncoder: ContextFreeEncoder<Scalar>
+    private var _documentEncoder: ContextFreeEncoder<Scalar>
     var zeroes: [Scalar]
     var count: Int { 0 }
-    
-    // var encodedDocuments: [Int : [Scalar]] = [:]
 
-    /// typicalNeighbourhoodSize is unknown
+    /// typicalNeighbourhoodSize = 20 is a standard benchmark
     var encodedDocuments: DeterministicSampleVectorIndex = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: 20)
     
     init(_documentEncoder: ContextFreeEncoder<Scalar>) {
@@ -49,67 +46,53 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
     
     @inlinable
     func addUntokenizedDocument(_ document: String) {
-        let encodedDocument = _documentEncoder.encodeSentence(document)
-        encodedDocuments.insert(vector: encodedDocument)
+        encodedDocuments.insert((_documentEncoder.encodeSentence(document) as! [Scalar]))
     }
-
-    // HNSW indexes do not support deletion - index must be rebuilt
 }
 
 
 public struct DeterministicSampleVectorIndex<Vector: Collection & Codable> where Vector.Element: BinaryFloatingPoint {
     
-    // MARK: Index accepts only [Double]
-    public typealias Index = EphemeralVectorIndex<Int, Int, CartesianDistanceMetric<[Double]>, Void>
+    /// EmphermalVectorIndex<Key: BinaryInteger, Level: BinaryInteger, Metric: SimilarityMetric, Metadata>
+    public typealias Index = EphemeralVectorIndex<Int, Int, CartesianDistanceMetric<Vector>, Void>
     public var base: Index
     
     public init(typicalNeighborhoodSize: Int) {
-        base = .init(metric: .init(), config: .unstableDefault(typicalNeighborhoodSize: typicalNeighborhoodSize))
+        base = .init(metric: CartesianDistanceMetric<Vector>(), config: .unstableDefault(typicalNeighborhoodSize: typicalNeighborhoodSize))
     }
     
-    // private var vectorRNG = DeterministicRandomNumberGenerator(seed: 0)
-    // private var graphRNG = DeterministicRandomNumberGenerator(seed: 1)
+    private var graphRNG = DeterministicRandomNumberGenerator(seed: 1)
     
     public func find(near query: Vector, limit: Int, exact: Bool = false) throws -> [Index.Neighbor] {
         if exact {
-            // Exact search
-            Array(PriorityHeap(base.vectors.enumerated().map {
-                let similarity = base.metric.similarity(between: query as! [Double], $0.element)
+            return Array(PriorityHeap(base.vectors.enumerated().map {
+                let similarity = base.metric.similarity(between: query, $0.element)
                 return NearbyVector(id: $0.offset, vector: $0.element, priority: similarity)
             }).descending().prefix(limit))
         } else {
-            // Approximation search
-            Array(try base.find(near: query as! [Double], limit: limit))
+            return Array(try base.find(near: query, limit: limit))
         }
     }
     
-    public mutating func insert(vector: Vector) {
-        // Convert the generic vector to a [Double], which is the required type for `base.insert`
-        let doubleVector = vector.map { Double($0) }
-
-        // Insert the vector into the HNSW graph
-        base.insert(doubleVector)
-    }
-
-    /*
-    public mutating func insertRandom(range: ClosedRange<Double>) {
-        base.insert(generateRandom(range: range) as! [Double], using: &graphRNG)
+    public mutating func insert(_ vector: Vector) {
+        let convertedVector: [Double] = vector.map{ Double($0) }
+        if let metricVector = convertedVector as? CartesianDistanceMetric<Vector>.Vector {
+            base.insert(metricVector, using: &graphRNG)
+        } else {
+            fatalError("Unable to get metric vector")
+        }
     }
-    */
 }
 
-public struct CartesianDistanceMetric<Vector: Collection & Codable>: SimilarityMetric where Vector.Element: BinaryFloatingPoint{
+public struct CartesianDistanceMetric<Vector: Collection & Codable>: SimilarityMetric where Vector.Element: BinaryFloatingPoint {
     public func similarity(between someItem: Vector, _ otherItem: Vector) -> Vector.Element {
-        // Naïve cartesian distance
         let squaredSum = zip(someItem, otherItem)
                 .map { (x, y) in (x - y) * (x - y) }
                 .reduce(0, +)
-        
         return sqrt(squaredSum)
     }
 }
 
-/*
 struct DeterministicRandomNumberGenerator: RandomNumberGenerator {
     private let randomSource: GKMersenneTwisterRandomSource
 
@@ -123,4 +106,3 @@ struct DeterministicRandomNumberGenerator: RandomNumberGenerator {
         return upperBits | lowerBits
     }
 }
-*/
-- 
GitLab