From f95be83297e24d3e8d917dcadcf1e1c9eca1f90d Mon Sep 17 00:00:00 2001
From: Mingchung Xia <mingchung.xia@gmail.com>
Date: Wed, 7 Feb 2024 21:15:43 -0500
Subject: [PATCH] ContextFreeEncoder for hnsw

---
 .../ContextFreeEncoder + File IO .swift       | 125 ++++++++++++++++++
 .../2. Encoding/ContextFreeEncoder.swift      |   7 +
 .../SwiftNLPTests/2. Encoding/HNSWTests.swift |   9 +-
 3 files changed, 138 insertions(+), 3 deletions(-)

diff --git a/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder + File IO .swift b/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder + File IO .swift
index cd5b35ef..58ea84c6 100644
--- a/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder + File IO .swift
+++ b/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder + File IO .swift
@@ -134,4 +134,129 @@ extension ContextFreeEncoder {
             return nil
         }
     }
+
+    /// Writes the vectors of `hnsw` to a binary file at `url`.
+    ///
+    /// Layout: the raw in-memory bytes of the vector count, followed by every
+    /// vector's components as consecutive 32-bit `Float` values. Previous file
+    /// contents are discarded. Errors opening the file are printed, not thrown.
+    ///
+    /// - Parameters:
+    ///   - url: Destination file; created if it does not exist.
+    ///   - hnsw: Index whose `base.vectors` are written.
+    static func writeHNSWToFile(url: URL, hnsw: DeterministicSampleVectorIndex<[Scalar]>) {
+        let fileManager = FileManager.default
+        if !fileManager.fileExists(atPath: url.path) {
+            fileManager.createFile(atPath: url.path, contents: nil, attributes: nil)
+        }
+
+        do {
+            let fileHandle = try FileHandle(forWritingTo: url)
+            defer { fileHandle.closeFile() }
+
+            // FileHandle(forWritingTo:) does not truncate, so clear old contents;
+            // otherwise a shorter index would leave stale bytes after the new data.
+            fileHandle.truncateFile(atOffset: 0)
+
+            let count = hnsw.base.vectors.count
+            let countData = withUnsafeBytes(of: count) { Data($0) }
+            fileHandle.write(countData)
+
+            for vector in hnsw.base.vectors {
+                // Components are stored as Float regardless of Scalar.
+                let vectorData = vector.map { Float($0) }.withUnsafeBufferPointer { Data(buffer: $0) }
+                fileHandle.write(vectorData)
+            }
+        } catch {
+            print("Error writing HNSW to file: \(error)")
+        }
+    }
+
+    /// Loads an index from a file produced by `writeHNSWToFile(url:hnsw:)`.
+    ///
+    /// The vector width is derived from the file size and the stored count
+    /// rather than assumed to be 50, and components are decoded as 32-bit
+    /// `Float` (the format the writer emits) before conversion to `Scalar`,
+    /// so the round trip is also correct when `Scalar` is `Double`.
+    ///
+    /// - Parameter url: Location of the binary index file.
+    /// - Returns: The loaded index, or an empty index when the file cannot be
+    ///   read or is malformed (the reason is printed).
+    static func readHNSWFromFile(_ url: URL) -> DeterministicSampleVectorIndex<[Scalar]> {
+        var hnsw = DeterministicSampleVectorIndex<[Scalar]>()
+
+        do {
+            let data = try Data(contentsOf: url, options: .alwaysMapped)
+            let headerSize = MemoryLayout<Int>.size
+            let componentSize = MemoryLayout<Float>.size
+
+            guard data.count >= headerSize else {
+                print("Error reading HNSW from file: missing vector count header")
+                return hnsw
+            }
+
+            // Copy the header bytes out; this has no alignment requirement.
+            var count = 0
+            withUnsafeMutableBytes(of: &count) { $0.copyBytes(from: data.prefix(headerSize)) }
+            if count == 0 { return hnsw }
+
+            let payloadSize = data.count - headerSize
+            let componentCount = payloadSize / componentSize
+            guard count > 0,
+                  payloadSize % componentSize == 0,
+                  componentCount > 0,
+                  componentCount % count == 0 else {
+                print("Error reading HNSW from file: size does not match vector count \(count)")
+                return hnsw
+            }
+
+            let width = componentCount / count
+            let vectorSize = width * componentSize
+            var offset = headerSize
+
+            for _ in 0..<count {
+                var components = [Float](repeating: 0, count: width)
+                components.withUnsafeMutableBufferPointer { buffer in
+                    _ = data.copyBytes(to: buffer, from: offset..<(offset + vectorSize))
+                }
+                hnsw.insert(components.map { Scalar($0) })
+                offset += vectorSize
+            }
+        } catch {
+            print("Error reading HNSW from file: \(error)")
+        }
+
+        return hnsw
+    }
+
+    /// Builds an index from a UTF-8 text file holding one vector per line.
+    ///
+    /// Components are separated by whitespace. Tokens that do not parse as a
+    /// number are skipped rather than force-unwrapped, and lines that yield
+    /// no components are ignored.
+    ///
+    /// - Parameter url: Location of the text file.
+    /// - Returns: The populated index, or `nil` if the file cannot be read.
+    static func readHNSWFromTextFile(from url: URL) -> DeterministicSampleVectorIndex<[Scalar]>? {
+        do {
+            let content = try String(contentsOf: url, encoding: .utf8)
+            var hnsw = DeterministicSampleVectorIndex<[Scalar]>()
+
+            // isNewline also matches "\r\n", which Swift treats as one Character
+            // and which therefore never equals a "\n" separator.
+            for line in content.split(whereSeparator: \.isNewline) {
+                let vector = line
+                    .split(whereSeparator: \.isWhitespace)
+                    .compactMap { Double($0) }
+                    .map { Scalar($0) }
+                guard !vector.isEmpty else { continue }
+                hnsw.insert(vector)
+            }
+
+            return hnsw
+        } catch {
+            print("Error loading vectors from text file: \(error)")
+            return nil
+        }
+    }
 }
diff --git a/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder.swift b/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder.swift
index b90ca71a..653716d1 100644
--- a/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder.swift
+++ b/Sources/SwiftNLP/2. Encoding/ContextFreeEncoder.swift
@@ -26,6 +26,7 @@ import Foundation
 class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
 
     var dictionary: [String : [Scalar]]
+    var hnsw: DeterministicSampleVectorIndex<[Scalar]>
     let width: Int
     var zeroes: [Scalar]
 
@@ -34,12 +35,14 @@ class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
     public enum PreComputedEmbeddings {
         case glove6B50d
         case glove6B100d
+        case hnswindex
         //case NLEmbedding
     }
 
     init(source: PreComputedEmbeddings) {
 
         dictionary = Dictionary<String,[Scalar]>()
+        hnsw = DeterministicSampleVectorIndex<[Scalar]>()
 
         var dictionaryToLoad: String
         switch source {
@@ -50,6 +53,10 @@ class ContextFreeEncoder<Scalar: BinaryFloatingPoint & Codable>: SNLPEncoder {
         case .glove6B100d:
             width = 100
             dictionaryToLoad = "glove.6B.100d"
+
+        case .hnswindex:
+            width = 50 // double check the dimension for HNSW
+            dictionaryToLoad = "hnswindex"
         }
 
         zeroes = Array(repeating: Scalar(0), count: width) as! [Scalar]
diff --git a/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift b/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift
index 0813d95e..a679db2b 100644
--- a/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift
+++ b/Tests/SwiftNLPTests/2. Encoding/HNSWTests.swift
@@ -14,7 +14,8 @@ final class HNSWTests: XCTestCase {
             "that enable us to train deep learning algorithms to learn like the human brain."
         ]
 
-        let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
+        // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
+        let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
         var corpus = HNSWCorpus(_documentEncoder: encoder)
         corpus.addUntokenizedDocuments(docs)
 
@@ -52,7 +53,8 @@ final class HNSWTests: XCTestCase {
             "All science is either physics or stamp collecting. - Ernest Rutherford"
         ]
 
-        let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
+        // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
+        let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
         var corpus = HNSWCorpus(_documentEncoder: encoder)
         corpus.addUntokenizedDocuments(twentyQuotes)
 
@@ -76,7 +78,8 @@ final class HNSWTests: XCTestCase {
 
         let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData)
 
-        let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
+        // let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
+        let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
         var corpus = HNSWCorpus(_documentEncoder: encoder)
 
         for submission in submissions {
--
GitLab