diff --git a/Sources/SwiftNLP/1. Data Collection/EphemeralHNSWCorpus + Sequence.swift b/Sources/SwiftNLP/1. Data Collection/EphemeralHNSWCorpus + Sequence.swift index 515db3b045fddc38859394e083d72fe7d9b3be78..0f3f4c68ee92c13805da3f5458adaca3e2c9eb98 100644 --- a/Sources/SwiftNLP/1. Data Collection/EphemeralHNSWCorpus + Sequence.swift +++ b/Sources/SwiftNLP/1. Data Collection/EphemeralHNSWCorpus + Sequence.swift @@ -21,43 +21,43 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR // OTHER DEALINGS IN THE SOFTWARE. +import Foundation + /// HNSWCorpus iterates through its dictionary of key to document vector pairs extension EphemeralHNSWCorpus: Sequence, Collection { - - typealias Element = [Scalar] - // Sequence Protocol Requirements @inlinable - func makeIterator() -> AnyIterator<Element> { - var index = 0 + func makeIterator() -> AnyIterator<DocumentVectorPair> { + var iterator = dictionary.values.makeIterator() return AnyIterator { - defer { index += 1 } - guard index < self.encodedDocuments.base.vectors.count else { return nil } - let element = self.encodedDocuments.base.vectors[index] // consider using .find - return element + return iterator.next() } } // Collection Protocol Requirements @inlinable var startIndex: Int { - return encodedDocuments.base.vectors.startIndex + return dictionary.keys.sorted().startIndex } @inlinable var endIndex: Int { - return encodedDocuments.base.vectors.endIndex + return dictionary.keys.sorted().endIndex } @inlinable - subscript(position: Int) -> Element { - return encodedDocuments.base.vectors[position] + subscript(position: Int) -> DocumentVectorPair { + let key = dictionary.keys.sorted()[position] + guard let pair = dictionary[key] else { + fatalError("Key \(key) not found in HNSW dictionary") + } + return pair } @inlinable func index(after i: Int) -> Int { - return encodedDocuments.base.vectors.index(after: i) + return dictionary.keys.sorted().index(after: i) } } diff --git a/Sources/SwiftNLP/1. Data Collection/HNSW/HNSWCorpusDataHandler.swift b/Sources/SwiftNLP/1. Data Collection/HNSW/HNSWCorpusDataHandler.swift index 88e95ff0437d0f1b62ca074f94a804f22216adfd..a6af8574d1cfeae396d260a8cfca0e3efb919c18 100644 --- a/Sources/SwiftNLP/1. Data Collection/HNSW/HNSWCorpusDataHandler.swift +++ b/Sources/SwiftNLP/1. Data Collection/HNSW/HNSWCorpusDataHandler.swift @@ -5,6 +5,8 @@ // Created by Mingchung Xia on 2024-02-13. // +// This is outdated since we now have the presence of a DurableHNSWCorpus but still available for reference + import Foundation final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> { @@ -13,7 +15,6 @@ final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> { init(corpus: HNSWCorpus<Scalar>, resource: String = "hnsw") { self.corpus = corpus - // TODO: Try to fix this to work in the Bundle (does not write but can read) // self.url = Bundle.module.url(forResource: resource, withExtension: "mmap") if let downloadsDirectory = FileManager.default.urls(for: .downloadsDirectory, in: .userDomainMask).first { self.url = downloadsDirectory.appendingPathComponent(resource + ".mmap") @@ -45,9 +46,9 @@ final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> { return size } -// private func heapSize(_ obj: AnyObject) -> Int { -// return malloc_size(Unmanaged.passUnretained(obj).toOpaque()) -// } + private func heapSize(_ obj: AnyObject) -> Int { + return malloc_size(Unmanaged.passUnretained(obj).toOpaque()) + } } extension HNSWCorpusDataHandler { @@ -66,14 +67,13 @@ extension HNSWCorpusDataHandler { // let count = corpus.count // let countData = withUnsafeBytes(of: count) { Data($0) } // fileHandle.write(countData) -// -// // TODO: We may need to edit the HNSWCorpus iterator to actually iterate over its dictionary as it would be useful here -// let data = corpus.getDictionary() -// for (key, documentVectorPair) in data { -// let documentData = documentVectorPair.untokenizedDocument.utf8CString.withUnsafeBufferPointer { Data(buffer: $0) } +// +// for pair in corpus { +// let documentData = pair.untokenizedDocument.utf8CString.withUnsafeBufferPointer { Data(buffer: $0) } // fileHandle.write(documentData) // } // fileHandle.closeFile() + print("Saving HNSW to file...") /// Using the Codable conformances let encoder = JSONEncoder() @@ -132,8 +132,4 @@ extension HNSWCorpusDataHandler { let encoder = ContextFreeEncoder<Scalar>(source: encoding) return loadMemoryMap(encoder: encoder, typicalNeighborhoodSize: typicalNeighborhoodSize, resource: resource) } - - static func loadDictionaryMemoryMap() { - // TODO: Move from DurableHNSW extension once HNSW wrapper is created - } } diff --git a/Tests/SwiftNLPTests/1. Data Collection/HNSW/DurableHNSWCorpusTests.swift b/Tests/SwiftNLPTests/1. Data Collection/HNSW/DurableHNSWCorpusTests.swift index 4360f27a0f1ce31d9cdc4228dd0902d1c59c9e83..56be08dbcce986a1c401dd3dd743a692d7c06670 100644 --- a/Tests/SwiftNLPTests/1. Data Collection/HNSW/DurableHNSWCorpusTests.swift +++ b/Tests/SwiftNLPTests/1. Data Collection/HNSW/DurableHNSWCorpusTests.swift @@ -5,7 +5,16 @@ import CoreLMDB import System @testable import SwiftNLP +// MARK: These tests are not to be included within the pipeline + final class DurableHNSWCorpusTests: XCTestCase { + /// This is used to skip these tests in the GitLab pipeline + override class var defaultTestSuite: XCTestSuite { + if ProcessInfo.processInfo.environment["SKIP_TESTS"] == "DurableHNSWCorpusTests" { + return XCTestSuite(name: "Empty") + } + return super.defaultTestSuite + } /// Setting up constants for environment private let ONE_GB: Int = 1_073_741_824 diff --git a/Tests/SwiftNLPTests/1. Data Collection/HNSW/EphemeralHNSWCorpusTests.swift b/Tests/SwiftNLPTests/1. Data Collection/HNSW/EphemeralHNSWCorpusTests.swift index a4594629ef2da9a54f3599ba9e5ce0f49bab7908..d6b39798ae1ef1700946895b01a988c8b59e8a4a 100644 --- a/Tests/SwiftNLPTests/1. Data Collection/HNSW/EphemeralHNSWCorpusTests.swift +++ b/Tests/SwiftNLPTests/1. Data Collection/HNSW/EphemeralHNSWCorpusTests.swift @@ -24,8 +24,8 @@ final class EphemeralHNSWCorpusTests: XCTestCase { XCTAssert(corpus.count == 3) /// Make sure none of our encodings are zero - for c in corpus { - XCTAssertNotEqual(c, corpus.zeroes) + for item in corpus { + XCTAssertNotEqual(item.vector, corpus.zeroes) } } @@ -60,8 +60,8 @@ final class EphemeralHNSWCorpusTests: XCTestCase { XCTAssertEqual(corpus.count, 20) /// Make sure none of our encodings are zero - for c in corpus { - XCTAssertNotEqual(c, corpus.zeroes) + for item in corpus { + XCTAssertNotEqual(item.vector, corpus.zeroes) } }