From c0327af8cfb4550f81ef908284b12fe9b9cf6fbe Mon Sep 17 00:00:00 2001
From: Mingchung Xia <mingchung.xia@gmail.com>
Date: Wed, 7 Feb 2024 22:41:04 -0500
Subject: [PATCH] Changed initializers of HNSWCorpus to match dictionary

---
 .../1. Data Collection/HNSWCorpus.swift       | 27 ++++++++++++++-----
 .../SwiftNLPTests/2. Encoding/HNSWTests.swift | 15 ++++++-----
 2 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/Sources/SwiftNLP/1. Data Collection/HNSWCorpus.swift b/Sources/SwiftNLP/1. Data Collection/HNSWCorpus.swift
index 0bd97e71..eb2a2b61 100644
--- a/Sources/SwiftNLP/1. Data Collection/HNSWCorpus.swift	
+++ b/Sources/SwiftNLP/1. Data Collection/HNSWCorpus.swift	
@@ -25,19 +25,33 @@ import Foundation
 
 class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
             
-    internal var _documentEncoder: ContextFreeEncoder<Scalar>
+    // internal var _documentEncoder: ContextFreeEncoder<Scalar>
+    internal var _documentEncoder: any SNLPEncoder
     var zeroes: [Scalar]
     var count: Int { 0 }
 
     // typicalNeighbourhoodSize = 20 is a standard benchmark
     var encodedDocuments: DeterministicSampleVectorIndex<[Scalar]>
-
-    init(_documentEncoder: ContextFreeEncoder<Scalar>, typicalNeighborhoodSize: Int = 20) {
-        self._documentEncoder = _documentEncoder
-        zeroes = Array(repeating: Scalar(0), count: 384)
+    
+    init(encoding: ContextFreeEncoder<Scalar>.PreComputedEmbeddings, scalar: Scalar.Type = Double.self,
+         typicalNeighborhoodSize: Int = 20) {
+        _documentEncoder = ContextFreeEncoder(source: encoding)
+        zeroes = _documentEncoder.zeroes as! [Scalar]
         encodedDocuments = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize)
     }
     
+    init(encoder: any SNLPEncoder, scalar: Scalar.Type = Double.self, typicalNeighborhoodSize: Int = 20) {
+        _documentEncoder = encoder
+        zeroes = _documentEncoder.zeroes as! [Scalar]
+        encodedDocuments = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize)
+    }
+
+//    init(_documentEncoder: ContextFreeEncoder<Scalar>, typicalNeighborhoodSize: Int = 20) {
+//        self._documentEncoder = _documentEncoder
+//        zeroes = Array(repeating: Scalar(0), count: 384)
+//        encodedDocuments = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize)
+//    }
+    
     // Decodable conformance
 //    required init(from decoder: Decoder) throws {
 //        let container = try decoder.container(keyedBy: CodingKeys.self)
@@ -48,7 +62,8 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
     
     @inlinable
     func addUntokenizedDocument(_ document: String) {
-        encodedDocuments.insert((_documentEncoder.encodeSentence(document))) /// as! [Scalar] not needed
+        /// force cast as! [Scalar] is needed when the encoder is any SNLPEncoder and not ContextFreeEncoder
+        encodedDocuments.insert((_documentEncoder.encodeSentence(document)) as! [Scalar])
     }
     
 //    @inlinable
diff --git a/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift b/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift
index a679db2b..5cc736b6 100644
--- a/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift	
+++ b/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift	
@@ -15,8 +15,9 @@ final class HNSWTests: XCTestCase {
          ]
         
         // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
-        let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
-        var corpus = HNSWCorpus(_documentEncoder: encoder)
+        // let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
+        // var corpus = HNSWCorpus(_documentEncoder: encoder)
+        var corpus = HNSWCorpus(encoding: .glove6B50d)
         corpus.addUntokenizedDocuments(docs)
         
         XCTAssert(corpus.encodedDocuments.base.vectors.count == 3)
@@ -54,8 +55,9 @@ final class HNSWTests: XCTestCase {
         ]
         
         // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
-        let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
-        var corpus = HNSWCorpus(_documentEncoder: encoder)
+        // let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
+        // var corpus = HNSWCorpus(_documentEncoder: encoder)
+        var corpus = HNSWCorpus(encoding: .glove6B50d)
         corpus.addUntokenizedDocuments(twentyQuotes)
         
         
@@ -79,8 +81,9 @@ final class HNSWTests: XCTestCase {
         let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData)
         
         // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
-        let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
-        var corpus = HNSWCorpus(_documentEncoder: encoder)
+        // let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
+        // var corpus = HNSWCorpus(_documentEncoder: encoder)
+        var corpus = HNSWCorpus(encoding: .glove6B50d)
         
         for submission in submissions {
             if let text = submission.selftext {
-- 
GitLab