diff --git a/Sources/SwiftNLP/1. Data Collection/HNSWCorpusDataHandler.swift b/Sources/SwiftNLP/1. Data Collection/HNSWCorpusDataHandler.swift index 7035dbf62011a4a604a1b7f66bb45869e43e5288..54b82701f4f4f76756f1984852828136863a4a00 100644 --- a/Sources/SwiftNLP/1. Data Collection/HNSWCorpusDataHandler.swift +++ b/Sources/SwiftNLP/1. Data Collection/HNSWCorpusDataHandler.swift @@ -8,8 +8,7 @@ import Foundation final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> { - private var corpus: HNSWCorpus<Scalar> - private var mmapURL: URL? // set default URL + var corpus: HNSWCorpus<Scalar> init(corpus: HNSWCorpus<Scalar>) { self.corpus = corpus @@ -17,11 +16,55 @@ final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> { } extension HNSWCorpusDataHandler { - func saveMemoryMap() { - + func saveMemoryMap(url: URL) { + let fileManager = FileManager.default + if !fileManager.fileExists(atPath: url.path) { + fileManager.createFile(atPath: url.path, contents: nil, attributes: nil) + } + do { + let fileHandle = try FileHandle(forWritingTo: url) + + let count = corpus.count + let countData = withUnsafeBytes(of: count) { Data($0) } + fileHandle.write(countData) + + // TODO: We may need to edit the HNSWCorpus iterator to actually iterate over its dictionary as it would be useful here + let data = corpus.getDictionary() + for (key, documentVectorPair) in data { + let documentData = documentVectorPair.untokenizedDocument.utf8CString.withUnsafeBufferPointer { Data(buffer: $0) } + fileHandle.write(documentData) + } + fileHandle.closeFile() + } catch { + print("Error writing HNSW to file: \(error)") + } } - func loadMemoryMap() { + // TODO: Change the return from Double to Scalar + func loadMemoryMap(url: URL, encoder: any SNLPEncoder) -> HNSWCorpus<Double> { + var loadedCorpus = HNSWCorpus(encoder: encoder) + do { + let data = try Data(contentsOf: url, options: .alwaysMapped) + let countData = data.prefix(MemoryLayout<Int>.size) + let count: Int = countData.withUnsafeBytes { $0.load(as: Int.self) } + var index = MemoryLayout<Int>.size + + for _ in 0..<count { + if let stringRange = data[index...].range(of: "\0".data(using: .utf8)!) { + let documentData = data[index..<stringRange.lowerBound] + if let document = String(data: documentData, encoding: .utf8) { + // Add the untokenized document to the corpus + loadedCorpus.addUntokenizedDocument(document) + index = stringRange.upperBound + } + } else { + break + } + } + } catch { + print("Error reading HNSW from file: \(error)") + } + return loadedCorpus } }