Skip to content
Snippets Groups Projects
Commit d660d689 authored by Mingchung Xia's avatar Mingchung Xia
Browse files

Test code cleanup

parent 92c1d9a8
No related branches found
No related tags found
1 merge request: !13 "HNSW Implementation with Testcases"
//
// File.swift
//
//
// Created by Mingchung Xia on 2024-02-14.
//
import Foundation
extension HNSWCorpus {
// MARK: - Dictionary operations

/// A document's original text stored together with its vector.
///
/// The corpus `dictionary` maps integer keys to values of this type.
public struct DocumentVectorPair {
    /// The document text in its untokenized form.
    var untokenizedDocument: String
    /// The vector stored alongside the document.
    var vector: [Scalar]

    // No explicit initializer: the compiler-synthesized memberwise
    // `init(untokenizedDocument:vector:)` has the same labels and the same
    // (internal) access level as a hand-written one would.
}
/// Returns the untokenized document stored under `key`.
///
/// - Parameter key: The key to look up in `dictionary`.
/// - Returns: The `untokenizedDocument` of the pair found at `key`.
/// - Note: Stops execution with `fatalError` when `dictionary` has no entry for `key`.
@inlinable
func getUntokenizedDocument(at key: Int) -> String {
    guard let entry = dictionary[key] else {
        fatalError("Key \(key) not found in HNSW dictionary")
    }
    return entry.untokenizedDocument
}
/// Returns the vector stored under `key`.
///
/// - Parameter key: The key to look up in `dictionary`.
/// - Returns: The `vector` of the pair found at `key`.
/// - Note: Stops execution with `fatalError` when `dictionary` has no entry for `key`.
@inlinable
func getVector(at key: Int) -> [Scalar] {
    guard let entry = dictionary[key] else {
        fatalError("Key \(key) not found in HNSW dictionary")
    }
    return entry.vector
}
/// Returns the corpus `dictionary`, keyed by integer, with one
/// `DocumentVectorPair` per entry.
@inlinable
func getDictionary() -> [Int: DocumentVectorPair] {
    dictionary
}
/// Stores `document` and `vector` together in `dictionary` under `key`.
///
/// - Parameters:
///   - key: The dictionary key to assign to.
///   - document: The untokenized document text.
///   - vector: The vector to store alongside the document.
private func addDocumentVectorPair(at key: Int, document: String, vector: [Scalar]) {
    let pair = DocumentVectorPair(untokenizedDocument: document, vector: vector)
    dictionary[key] = pair
}
}
...@@ -14,9 +14,6 @@ final class HNSWTests: XCTestCase { ...@@ -14,9 +14,6 @@ final class HNSWTests: XCTestCase {
"that enable us to train deep learning algorithms to learn like the human brain." "that enable us to train deep learning algorithms to learn like the human brain."
] ]
// let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
// let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
// var corpus = HNSWCorpus(_documentEncoder: encoder)
var corpus = HNSWCorpus(encoding: .glove6B50d) var corpus = HNSWCorpus(encoding: .glove6B50d)
corpus.addUntokenizedDocuments(docs) corpus.addUntokenizedDocuments(docs)
...@@ -57,9 +54,6 @@ final class HNSWTests: XCTestCase { ...@@ -57,9 +54,6 @@ final class HNSWTests: XCTestCase {
"All science is either physics or stamp collecting. - Ernest Rutherford" "All science is either physics or stamp collecting. - Ernest Rutherford"
] ]
// let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
// let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
// var corpus = HNSWCorpus(_documentEncoder: encoder)
var corpus = HNSWCorpus(encoding: .glove6B50d) var corpus = HNSWCorpus(encoding: .glove6B50d)
corpus.addUntokenizedDocuments(twentyQuotes) corpus.addUntokenizedDocuments(twentyQuotes)
...@@ -85,9 +79,6 @@ final class HNSWTests: XCTestCase { ...@@ -85,9 +79,6 @@ final class HNSWTests: XCTestCase {
let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData) let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData)
// let encoder = ContextFreeEncoder<Double>(source: .glove6B50d)
// let encoder = ContextFreeEncoder<Double>(source: .hnswindex)
// var corpus = HNSWCorpus(_documentEncoder: encoder)
var corpus = HNSWCorpus(encoding: .glove6B50d) var corpus = HNSWCorpus(encoding: .glove6B50d)
for submission in submissions { for submission in submissions {
...@@ -172,7 +163,6 @@ final class HNSWTests: XCTestCase { ...@@ -172,7 +163,6 @@ final class HNSWTests: XCTestCase {
do { do {
print("Attempting to query corpus.encodedDocuments.find()...") print("Attempting to query corpus.encodedDocuments.find()...")
// TODO: Print this as a readable result - reverse encoding?
let queryVector: [Double] = _documentEncoder.encodeToken(query).map { Double($0) } let queryVector: [Double] = _documentEncoder.encodeToken(query).map { Double($0) }
let results = try corpus.encodedDocuments.find(near: queryVector, limit: 8) let results = try corpus.encodedDocuments.find(near: queryVector, limit: 8)
...@@ -223,7 +213,48 @@ final class HNSWTests: XCTestCase { ...@@ -223,7 +213,48 @@ final class HNSWTests: XCTestCase {
do { do {
print("Attempting to query corpus.encodedDocuments.find()...") print("Attempting to query corpus.encodedDocuments.find()...")
// TODO: Print this as a readable result - reverse encoding? let queryVector: [Double] = _documentEncoder.encodeToken(query).map { Double($0) }
let results = try corpus.encodedDocuments.find(near: queryVector, limit: 8)
for result in results {
print(corpus.getUntokenizedDocument(at: result.id))
}
print("Query completed!")
} catch {
print("Error when trying corpus.encodedDocuments.find(): \(error)")
}
}
// TODO: Get HNSWCorpus from memory map
func testSubredditQueryExample() async throws {
guard let submissionsURL = Bundle.module.url(forResource: "Guelph_submissions", withExtension: "zst") else {
fatalError("Failed to find waterloo_submissions.zst in test bundle.")
}
guard let submissionsData = try? Data(contentsOf: submissionsURL) else {
fatalError("Failed to load waterloo_submissions.zst from test bundle.")
}
let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData)
let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d)
var corpus = HNSWCorpus(encoder: _documentEncoder)
for submission in submissions {
if let text = submission.selftext {
corpus.addUntokenizedDocument(text)
}
}
let query = "Mr. Goose is a very important figure at the University of Waterloo."
let size = MemoryLayout.size(ofValue: corpus)
print("Approximate memory footprint: \(size) bytes")
do {
print("Attempting to query corpus.encodedDocuments.find()...")
let queryVector: [Double] = _documentEncoder.encodeToken(query).map { Double($0) } let queryVector: [Double] = _documentEncoder.encodeToken(query).map { Double($0) }
let results = try corpus.encodedDocuments.find(near: queryVector, limit: 8) let results = try corpus.encodedDocuments.find(near: queryVector, limit: 8)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment