Skip to content
Snippets Groups Projects
Commit 20d4ebcb authored by Mingchung Xia
Browse files

Ephemeral sequence conformance and cleanup

parent 645b982d
No related branches found
No related tags found
1 merge request: !13 HNSW Implementation with Testcases
Pipeline #114211 passed with warnings
...@@ -21,43 +21,43 @@ ...@@ -21,43 +21,43 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE. // OTHER DEALINGS IN THE SOFTWARE.
import Foundation
/// HNSWCorpus iterates through its dictionary of key to document vector pairs /// HNSWCorpus iterates through its dictionary of key to document vector pairs
extension EphemeralHNSWCorpus: Sequence, Collection { extension EphemeralHNSWCorpus: Sequence, Collection {
typealias Element = [Scalar]
// Sequence Protocol Requirements // Sequence Protocol Requirements
@inlinable @inlinable
func makeIterator() -> AnyIterator<Element> { func makeIterator() -> AnyIterator<DocumentVectorPair> {
var index = 0 var iterator = dictionary.values.makeIterator()
return AnyIterator { return AnyIterator {
defer { index += 1 } return iterator.next()
guard index < self.encodedDocuments.base.vectors.count else { return nil }
let element = self.encodedDocuments.base.vectors[index] // consider using .find
return element
} }
} }
// Collection Protocol Requirements // Collection Protocol Requirements
@inlinable @inlinable
var startIndex: Int { var startIndex: Int {
return encodedDocuments.base.vectors.startIndex return dictionary.keys.sorted().startIndex
} }
@inlinable @inlinable
var endIndex: Int { var endIndex: Int {
return encodedDocuments.base.vectors.endIndex return dictionary.keys.sorted().endIndex
} }
@inlinable @inlinable
subscript(position: Int) -> Element { subscript(position: Int) -> DocumentVectorPair {
return encodedDocuments.base.vectors[position] let key = dictionary.keys.sorted()[position]
guard let pair = dictionary[key] else {
fatalError("Key \(key) not found in HNSW dictionary")
}
return pair
} }
@inlinable @inlinable
func index(after i: Int) -> Int { func index(after i: Int) -> Int {
return encodedDocuments.base.vectors.index(after: i) return dictionary.keys.sorted().index(after: i)
} }
} }
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
// Created by Mingchung Xia on 2024-02-13. // Created by Mingchung Xia on 2024-02-13.
// //
// This is outdated since we now have the presence of a DurableHNSWCorpus but still available for reference
import Foundation import Foundation
final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> { final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> {
...@@ -13,7 +15,6 @@ final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> { ...@@ -13,7 +15,6 @@ final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> {
init(corpus: HNSWCorpus<Scalar>, resource: String = "hnsw") { init(corpus: HNSWCorpus<Scalar>, resource: String = "hnsw") {
self.corpus = corpus self.corpus = corpus
// TODO: Try to fix this to work in the Bundle (does not write but can read)
// self.url = Bundle.module.url(forResource: resource, withExtension: "mmap") // self.url = Bundle.module.url(forResource: resource, withExtension: "mmap")
if let downloadsDirectory = FileManager.default.urls(for: .downloadsDirectory, in: .userDomainMask).first { if let downloadsDirectory = FileManager.default.urls(for: .downloadsDirectory, in: .userDomainMask).first {
self.url = downloadsDirectory.appendingPathComponent(resource + ".mmap") self.url = downloadsDirectory.appendingPathComponent(resource + ".mmap")
...@@ -45,9 +46,9 @@ final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> { ...@@ -45,9 +46,9 @@ final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> {
return size return size
} }
// private func heapSize(_ obj: AnyObject) -> Int { private func heapSize(_ obj: AnyObject) -> Int {
// return malloc_size(Unmanaged.passUnretained(obj).toOpaque()) return malloc_size(Unmanaged.passUnretained(obj).toOpaque())
// } }
} }
extension HNSWCorpusDataHandler { extension HNSWCorpusDataHandler {
...@@ -66,14 +67,13 @@ extension HNSWCorpusDataHandler { ...@@ -66,14 +67,13 @@ extension HNSWCorpusDataHandler {
// let count = corpus.count // let count = corpus.count
// let countData = withUnsafeBytes(of: count) { Data($0) } // let countData = withUnsafeBytes(of: count) { Data($0) }
// fileHandle.write(countData) // fileHandle.write(countData)
// //
// // TODO: We may need to edit the HNSWCorpus iterator to actually iterate over its dictionary as it would be useful here // for pair in corpus {
// let data = corpus.getDictionary() // let documentData = pair.untokenizedDocument.utf8CString.withUnsafeBufferPointer { Data(buffer: $0) }
// for (key, documentVectorPair) in data {
// let documentData = documentVectorPair.untokenizedDocument.utf8CString.withUnsafeBufferPointer { Data(buffer: $0) }
// fileHandle.write(documentData) // fileHandle.write(documentData)
// } // }
// fileHandle.closeFile() // fileHandle.closeFile()
print("Saving HNSW to file...") print("Saving HNSW to file...")
/// Using the Codable conformances /// Using the Codable conformances
let encoder = JSONEncoder() let encoder = JSONEncoder()
...@@ -132,8 +132,4 @@ extension HNSWCorpusDataHandler { ...@@ -132,8 +132,4 @@ extension HNSWCorpusDataHandler {
let encoder = ContextFreeEncoder<Scalar>(source: encoding) let encoder = ContextFreeEncoder<Scalar>(source: encoding)
return loadMemoryMap(encoder: encoder, typicalNeighborhoodSize: typicalNeighborhoodSize, resource: resource) return loadMemoryMap(encoder: encoder, typicalNeighborhoodSize: typicalNeighborhoodSize, resource: resource)
} }
static func loadDictionaryMemoryMap() {
// TODO: Move from DurableHNSW extension once HNSW wrapper is created
}
} }
...@@ -5,7 +5,16 @@ import CoreLMDB ...@@ -5,7 +5,16 @@ import CoreLMDB
import System import System
@testable import SwiftNLP @testable import SwiftNLP
// MARK: These tests are not to be included within the pipeline
final class DurableHNSWCorpusTests: XCTestCase { final class DurableHNSWCorpusTests: XCTestCase {
/// This is used to skip these tests in the GitLab pipeline
override class var defaultTestSuite: XCTestSuite {
if ProcessInfo.processInfo.environment["SKIP_TESTS"] == "DurableHNSWCorpusTests" {
return XCTestSuite(name: "Empty")
}
return super.defaultTestSuite
}
/// Setting up constants for environment /// Setting up constants for environment
private let ONE_GB: Int = 1_073_741_824 private let ONE_GB: Int = 1_073_741_824
......
...@@ -24,8 +24,8 @@ final class EphemeralHNSWCorpusTests: XCTestCase { ...@@ -24,8 +24,8 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
XCTAssert(corpus.count == 3) XCTAssert(corpus.count == 3)
/// Make sure none of our encodings are zero /// Make sure none of our encodings are zero
for c in corpus { for item in corpus {
XCTAssertNotEqual(c, corpus.zeroes) XCTAssertNotEqual(item.vector, corpus.zeroes)
} }
} }
...@@ -60,8 +60,8 @@ final class EphemeralHNSWCorpusTests: XCTestCase { ...@@ -60,8 +60,8 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
XCTAssertEqual(corpus.count, 20) XCTAssertEqual(corpus.count, 20)
/// Make sure none of our encodings are zero /// Make sure none of our encodings are zero
for c in corpus { for item in corpus {
XCTAssertNotEqual(c, corpus.zeroes) XCTAssertNotEqual(item.vector, corpus.zeroes)
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment