Skip to content
Snippets Groups Projects
Commit 20d4ebcb authored by Mingchung Xia's avatar Mingchung Xia
Browse files

Ephemeral sequence conformance and cleanup

parent 645b982d
No related branches found
No related tags found
1 merge request: !13 "HNSW Implementation with Testcases"
Pipeline #114211 passed with warnings
......@@ -21,43 +21,43 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE.
import Foundation
/// HNSWCorpus iterates through its dictionary of key to document vector pairs
// NOTE(review): this span contains two variants of most members: one reading
// `encodedDocuments.base.vectors` (element type `[Scalar]`) and one reading
// `dictionary` (element type `DocumentVectorPair`). Presumably these are the two
// sides of a diff rendered without +/- markers; confirm against the repository
// before relying on either variant.
extension EphemeralHNSWCorpus: Sequence, Collection {
typealias Element = [Scalar]
// Sequence Protocol Requirements
@inlinable
func makeIterator() -> AnyIterator<Element> {
var index = 0
// Dictionary-backed variant: wraps the iterator of `dictionary.values`, so
// elements come out in whatever order the dictionary enumerates its values.
func makeIterator() -> AnyIterator<DocumentVectorPair> {
var iterator = dictionary.values.makeIterator()
return AnyIterator {
defer { index += 1 }
guard index < self.encodedDocuments.base.vectors.count else { return nil }
let element = self.encodedDocuments.base.vectors[index] // consider using .find
return element
return iterator.next()
}
}
// Collection Protocol Requirements
// NOTE(review): every dictionary-backed member below calls
// `dictionary.keys.sorted()` again, so each index access sorts all keys anew.
// `Collection` documents `startIndex`, `endIndex` and subscripting as expected
// O(1) operations; consider caching the sorted keys.
// NOTE(review): indexing follows sorted-key order while `makeIterator()` follows
// `dictionary.values` order, so `for` loops and positional access may disagree
// on element order; verify this is intended.
@inlinable
var startIndex: Int {
return encodedDocuments.base.vectors.startIndex
return dictionary.keys.sorted().startIndex
}
@inlinable
var endIndex: Int {
return encodedDocuments.base.vectors.endIndex
return dictionary.keys.sorted().endIndex
}
@inlinable
subscript(position: Int) -> Element {
return encodedDocuments.base.vectors[position]
// Dictionary-backed variant: maps `position` to the key at that offset in the
// sorted key array, then looks the pair up in `dictionary`.
subscript(position: Int) -> DocumentVectorPair {
let key = dictionary.keys.sorted()[position]
// The key was just read from `dictionary.keys`, so a missing entry is treated
// as an unrecoverable inconsistency.
guard let pair = dictionary[key] else {
fatalError("Key \(key) not found in HNSW dictionary")
}
return pair
}
@inlinable
func index(after i: Int) -> Int {
return encodedDocuments.base.vectors.index(after: i)
return dictionary.keys.sorted().index(after: i)
}
}
......@@ -5,6 +5,8 @@
// Created by Mingchung Xia on 2024-02-13.
//
// This is outdated now that DurableHNSWCorpus exists, but it is kept available for reference.
import Foundation
final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> {
......@@ -13,7 +15,6 @@ final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> {
init(corpus: HNSWCorpus<Scalar>, resource: String = "hnsw") {
self.corpus = corpus
// TODO: Try to fix this to work in the Bundle (does not write but can read)
// self.url = Bundle.module.url(forResource: resource, withExtension: "mmap")
if let downloadsDirectory = FileManager.default.urls(for: .downloadsDirectory, in: .userDomainMask).first {
self.url = downloadsDirectory.appendingPathComponent(resource + ".mmap")
......@@ -45,9 +46,9 @@ final class HNSWCorpusDataHandler<Scalar: BinaryFloatingPoint & Codable> {
return size
}
// private func heapSize(_ obj: AnyObject) -> Int {
// return malloc_size(Unmanaged.passUnretained(obj).toOpaque())
// }
/// Reports the size, in bytes, that `malloc_size` returns for the allocation backing `obj`.
///
/// - Parameter obj: The object whose allocation is measured. It is passed
///   unretained, so this call does not change its reference count.
/// - Returns: The value returned by `malloc_size` for the object's address.
private func heapSize(_ obj: AnyObject) -> Int {
    let address = Unmanaged.passUnretained(obj).toOpaque()
    let byteCount = malloc_size(address)
    return byteCount
}
}
extension HNSWCorpusDataHandler {
......@@ -66,14 +67,13 @@ extension HNSWCorpusDataHandler {
// let count = corpus.count
// let countData = withUnsafeBytes(of: count) { Data($0) }
// fileHandle.write(countData)
//
// // TODO: We may need to edit the HNSWCorpus iterator to actually iterate over its dictionary as it would be useful here
// let data = corpus.getDictionary()
// for (key, documentVectorPair) in data {
// let documentData = documentVectorPair.untokenizedDocument.utf8CString.withUnsafeBufferPointer { Data(buffer: $0) }
//
// for pair in corpus {
// let documentData = pair.untokenizedDocument.utf8CString.withUnsafeBufferPointer { Data(buffer: $0) }
// fileHandle.write(documentData)
// }
// fileHandle.closeFile()
print("Saving HNSW to file...")
/// Using the Codable conformances
let encoder = JSONEncoder()
......@@ -132,8 +132,4 @@ extension HNSWCorpusDataHandler {
let encoder = ContextFreeEncoder<Scalar>(source: encoding)
return loadMemoryMap(encoder: encoder, typicalNeighborhoodSize: typicalNeighborhoodSize, resource: resource)
}
/// Placeholder for loading a dictionary memory map.
///
/// The body is currently empty: the function takes no arguments, returns
/// nothing, and performs no work when called.
static func loadDictionaryMemoryMap() {
// TODO: Move from DurableHNSW extension once HNSW wrapper is created
}
}
......@@ -5,7 +5,16 @@ import CoreLMDB
import System
@testable import SwiftNLP
// MARK: These tests are not to be included within the pipeline
final class DurableHNSWCorpusTests: XCTestCase {
/// Lets the GitLab pipeline skip this test case.
///
/// When the `SKIP_TESTS` environment variable equals `"DurableHNSWCorpusTests"`,
/// a new suite named "Empty" is returned instead of the inherited default suite.
override class var defaultTestSuite: XCTestSuite {
    let skipTarget = ProcessInfo.processInfo.environment["SKIP_TESTS"]
    guard skipTarget == "DurableHNSWCorpusTests" else {
        return super.defaultTestSuite
    }
    return XCTestSuite(name: "Empty")
}
/// Setting up constants for environment
private let ONE_GB: Int = 1_073_741_824
......
......@@ -24,8 +24,8 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
XCTAssert(corpus.count == 3)
/// Make sure none of our encodings are zero
for c in corpus {
XCTAssertNotEqual(c, corpus.zeroes)
for item in corpus {
XCTAssertNotEqual(item.vector, corpus.zeroes)
}
}
......@@ -60,8 +60,8 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
XCTAssertEqual(corpus.count, 20)
/// Make sure none of our encodings are zero
for c in corpus {
XCTAssertNotEqual(c, corpus.zeroes)
for item in corpus {
XCTAssertNotEqual(item.vector, corpus.zeroes)
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment