From f95be83297e24d3e8d917dcadcf1e1c9eca1f90d Mon Sep 17 00:00:00 2001
From: Mingchung Xia <mingchung.xia@gmail.com>
Date: Wed, 7 Feb 2024 21:15:43 -0500
Subject: [PATCH] ContextFreeEncoder for hnsw

---
 .../ContextFreeEncoder + File IO .swift       | 64 +++++++++++++++++++
 .../2. Encoding/ContextFreeEncoder.swift      |  7 ++
 .../SwiftNLPTests/2. Encoding/HNSWTests.swift |  9 ++-
 3 files changed, 77 insertions(+), 3 deletions(-)

diff --git a/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder + File IO .swift b/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder + File IO .swift
index cd5b35ef..58ea84c6 100644
--- a/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder + File IO .swift	
+++ b/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder + File IO .swift	
@@ -134,4 +134,68 @@ extension ContextFreeEncoder {
             return nil
         }
     }
+    
+    static func writeHNSWToFile(url: URL, hnsw: DeterministicSampleVectorIndex<[Scalar]>) {
+        let fileManager = FileManager.default
+        if !fileManager.fileExists(atPath: url.path) {
+            fileManager.createFile(atPath: url.path, contents: nil, attributes: nil)
+        }
+        
+        do {
+            let fileHandle = try FileHandle(forWritingTo: url)
+            
+            let count = hnsw.base.vectors.count
+            let countData = withUnsafeBytes(of: count) { Data($0) }
+            fileHandle.write(countData)
+            
+            for vector in hnsw.base.vectors {
+                let vectorData = vector.map { Float($0) }.withUnsafeBufferPointer { Data(buffer: $0) }
+                fileHandle.write(vectorData)
+            }
+            
+            fileHandle.closeFile()
+        } catch {
+            print("Error writing HNSW to file: \(error)")
+        }
+    }
+    
+    static func readHNSWFromFile(_ url: URL) -> DeterministicSampleVectorIndex<[Scalar]> {
+        do {
+            let data = try Data(contentsOf: url, options: .alwaysMapped)
+            let countData = data.prefix(MemoryLayout<Int>.size)
+            let count: Int = countData.withUnsafeBytes { $0.load(as: Int.self) }
+            var index = MemoryLayout<Int>.size
+            
+            var hnsw = DeterministicSampleVectorIndex<[Scalar]>()
+            
+            for _ in 0..<count {
+                let vectorData = data[index..<(index + 50 * MemoryLayout<Scalar>.size)]
+                let vector = vectorData.withUnsafeBytes { Array($0.bindMemory(to: Scalar.self)) }
+                hnsw.insert(vector)
+                index += 50 * MemoryLayout<Scalar>.size
+            }
+
+            return hnsw
+        } catch {
+            print("Error reading HNSW from file: \(error)")
+        }
+        return DeterministicSampleVectorIndex<[Scalar]>()
+    }
+    
+    static func readHNSWFromTextFile(from url: URL) -> DeterministicSampleVectorIndex<[Scalar]>? {
+        do {
+            let content = try String(contentsOf: url, encoding: .utf8)
+            let lines = content.split(separator: "\n")
+            var hnsw = DeterministicSampleVectorIndex<[Scalar]>()
+            for line in lines {
+                let vector = line.split(separator: " ").compactMap { Scalar(Double($0)!) }
+                hnsw.insert(vector)
+            }
+
+            return hnsw
+        } catch {
+            print("Error loading vectors from text file: \(error)")
+            return nil
+        }
+    }
 }
diff --git a/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder.swift b/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder.swift
index b90ca71a..653716d1 100644
--- a/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder.swift	
+++ b/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder.swift	
@@ -26,6 +26,7 @@ import Foundation
 class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
     
     var dictionary: [String : [Scalar]]
+    var hnsw: DeterministicSampleVectorIndex<[Scalar]>
     let width: Int
     var zeroes: [Scalar]
     
@@ -34,12 +35,14 @@ class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
     public enum PreComputedEmbeddings {
         case glove6B50d
         case glove6B100d
+        case hnswindex
         //case NLEmbedding
     }
     
     init(source: PreComputedEmbeddings) {
         
         dictionary = Dictionary<String,[Scalar]>()
+        hnsw = DeterministicSampleVectorIndex<[Scalar]>()
         
         var dictionaryToLoad: String
         switch source {
@@ -50,6 +53,10 @@ class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
         case .glove6B100d:
             width = 100
             dictionaryToLoad = "glove.6B.100d"
+            
+        case .hnswindex:
+            width = 50 // double check the dimension for HNSW
+            dictionaryToLoad = "hnswindex"
         }
         
         zeroes = Array(repeating: Scalar(0), count: width) as! [Scalar]
diff --git a/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift b/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift
index 0813d95e..a679db2b 100644
--- a/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift	
+++ b/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift	
@@ -14,7 +14,8 @@ final class HNSWTests: XCTestCase {
             "that enable us to train deep learning algorithms to learn like the human brain."
          ]
         
-        let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
+        // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
+        let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
         var corpus = HNSWCorpus(_documentEncoder: encoder)
         corpus.addUntokenizedDocuments(docs)
         
@@ -52,7 +53,8 @@ final class HNSWTests: XCTestCase {
             "All science is either physics or stamp collecting. - Ernest Rutherford"
         ]
         
-        let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
+        // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
+        let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
         var corpus = HNSWCorpus(_documentEncoder: encoder)
         corpus.addUntokenizedDocuments(twentyQuotes)
         
@@ -76,7 +78,8 @@ final class HNSWTests: XCTestCase {
         
         let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData)
         
-        let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
+        // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
+        let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
         var corpus = HNSWCorpus(_documentEncoder: encoder)
         
         for submission in submissions {
-- 
GitLab