Skip to content
Snippets Groups Projects
Commit 4bd282f5 authored by Jim Wallace's avatar Jim Wallace
Browse files

Cleaned up search ergonomics for HNSW

parent 7501b232
No related branches found
No related tags found
1 merge request: !14 — Improved ergonomics for generic types: SNLPCorpus, SNLPEncoder, InMemoryCorpus
Pipeline #115889 passed with warnings
......@@ -30,9 +30,9 @@ extension EphemeralHNSWCorpus {
/// This extension is used for the dictionary operations
/// Pairs the original (untokenized) document text with its encoded vector so
/// client code can recover the source document for a given dictionary key.
/// Uses the encoder's associated scalar type rather than a separate generic
/// parameter, consistent with the slimmed-down `EphemeralHNSWCorpus` signature.
public struct DocumentVectorPair {
    var untokenizedDocument: String
    var vector: [Encoder.Scalar]

    /// - Parameters:
    ///   - untokenizedDocument: The raw document text as supplied by the caller.
    ///   - vector: The document's embedding, in the encoder's scalar type.
    init(untokenizedDocument: String, vector: [Encoder.Scalar]) {
        self.untokenizedDocument = untokenizedDocument
        self.vector = vector
    }
}
......@@ -48,7 +48,7 @@ extension EphemeralHNSWCorpus {
}
@inlinable
func getVector(at key: Int) -> [Scalar] {
func getVector(at key: Int) -> [Encoder.Scalar] {
if let pair = dictionary[key] {
return pair.vector
} else {
......@@ -61,7 +61,7 @@ extension EphemeralHNSWCorpus {
return dictionary
}
func addDocumentVectorPair(at key: Int, document: String, vector: [Scalar]) {
func addDocumentVectorPair(at key: Int, document: String, vector: [Encoder.Scalar]) {
dictionary[key] = DocumentVectorPair(
untokenizedDocument: document,
vector: vector
......@@ -69,7 +69,7 @@ extension EphemeralHNSWCorpus {
}
}
extension EphemeralHNSWCorpus.DocumentVectorPair: Codable where Scalar: Codable {
extension EphemeralHNSWCorpus.DocumentVectorPair: Codable where Encoder.Scalar: Codable {
enum CodingKeys: String, CodingKey {
case untokenizedDocument
case vector
......@@ -78,7 +78,7 @@ extension EphemeralHNSWCorpus.DocumentVectorPair: Codable where Scalar: Codable
/// Decodes a pair from keyed storage.
/// Available only when `Encoder.Scalar: Codable` (see the conditional
/// conformance on this extension).
/// - Throws: `DecodingError` if either key is missing or has the wrong type.
internal init(from decoder: Decoder) throws {
    let container = try decoder.container(keyedBy: CodingKeys.self)
    untokenizedDocument = try container.decode(String.self, forKey: .untokenizedDocument)
    vector = try container.decode([Encoder.Scalar].self, forKey: .vector)
}
// internal func encode(to encoder: Encoder) throws {
......
......@@ -34,35 +34,32 @@ import Foundation
// MARK: Allow EphemeralHNSWCorpus to simply be used as HNSWCorpus
typealias HNSWCorpus = EphemeralHNSWCorpus
final class EphemeralHNSWCorpus<Scalar: BinaryFloatingPoint & Codable, Encoder: SNLPEncoder, Item: SNLPDataItem>: SNLPCorpus where Encoder.Scalar == Scalar {
final class EphemeralHNSWCorpus<Item: SNLPDataItem,Encoder: SNLPEncoder>: SNLPCorpus where Encoder.Scalar: Codable{
public typealias HNSWDictionary = [Int: DocumentVectorPair]
internal var documentEncoder: Encoder
internal var documents = ContiguousArray<Item>()
internal var encodedDocuments = ContiguousArray<[Scalar]>()
internal var encodedDocuments = ContiguousArray<[Encoder.Scalar]>()
//var zeroes: [Scalar] { _documentEncoder.zeroes }
//var count: Int { encodedDocuments.base.vectors.count }
var index: DeterministicEphemeralVectorIndex<[Scalar]>
var index: DeterministicEphemeralVectorIndex<[Encoder.Scalar]>
// Keeps track of the original document for client code
var dictionary: HNSWDictionary = [:]
// typicalNeighbourhoodSize = 20 is a standard benchmark
// Former convenience initializer, superseded by `init(encoder:typicalNeighborhoodSize:)`
// now that `SNLPEncoder` lets the corpus default-construct its encoder.
// Kept for reference; the `as! Encoder` force-cast made it unsafe for
// encoder types other than `ContextFreeEncoder`.
// init(encoding: ContextFreeEncoder<Encoder.Scalar>.PreComputedEmbeddings,
//      typicalNeighborhoodSize: Int = 20) {
//     documentEncoder = ContextFreeEncoder(source: encoding) as! Encoder
//     index = DeterministicEphemeralVectorIndex<[Encoder.Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize)
// }

/// Creates a corpus backed by a deterministic ephemeral HNSW vector index.
/// - Parameters:
///   - encoder: Encoder used to embed documents; defaults to `Encoder()`.
///   - typicalNeighborhoodSize: HNSW construction parameter;
///     20 is the standard benchmark value noted above.
init(encoder: Encoder = Encoder(), typicalNeighborhoodSize: Int = 20) {
    documentEncoder = encoder
    index = DeterministicEphemeralVectorIndex<[Encoder.Scalar]>(typicalNeighborhoodSize: typicalNeighborhoodSize)
}
// // Decodable conformance
......@@ -92,6 +89,10 @@ final class EphemeralHNSWCorpus<Scalar: BinaryFloatingPoint & Codable, Encoder:
}
/// Returns up to 8 items whose embeddings are nearest to `query`.
/// - Parameter query: Raw query text; encoded with `documentEncoder.encodeToken`.
///   (NOTE(review): `encodeToken` presumably embeds the whole string as one
///   token — confirm this is the intended ergonomics for multi-word queries.)
/// - Returns: Matching items in index order, or an empty array if the
///   index search fails.
func searchFor(_ query: String) -> [Item] {
    let queryVector = documentEncoder.encodeToken(query)
    // Avoid `try!`: a failing index lookup should degrade to "no results",
    // not crash the process.
    guard let results = try? index.find(near: queryVector, limit: 8) else {
        return []
    }
    return results.map { documents[$0.id] }
}
}
......@@ -15,7 +15,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
"that enable us to train deep learning algorithms to learn like the human brain."
]
var corpus = HNSWCorpus<Double,ContextFreeEncoder,String>(encoding: .glove6B50d)
var corpus = HNSWCorpus<String,ContextFreeEncoder<Double>>()
corpus.addUntokenizedDocuments(docs)
XCTAssert(corpus.count == 3)
......@@ -51,7 +51,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
"All science is either physics or stamp collecting. - Ernest Rutherford"
]
var corpus = HNSWCorpus<Double,ContextFreeEncoder,String>(encoding: .glove6B50d)
var corpus = HNSWCorpus<String,ContextFreeEncoder<Double>>()
corpus.addUntokenizedDocuments(twentyQuotes)
XCTAssertEqual(corpus.count, 20)
......@@ -72,7 +72,7 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData)
let corpus = HNSWCorpus<Double,ContextFreeEncoder,String>(encoding: .glove6B50d)
let corpus = HNSWCorpus<String,ContextFreeEncoder<Double>>()
for submission in submissions {
if let text = submission.selftext {
......@@ -98,20 +98,22 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
let query = "I like to read about new technology and artificial intelligence"
let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d)
var corpus = HNSWCorpus<Double,ContextFreeEncoder,String>(encoding: .glove6B50d)
//let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d)
var corpus = HNSWCorpus<String,ContextFreeEncoder<Double>>()
corpus.addUntokenizedDocuments(docs)
do {
let queryVector: [Double] = _documentEncoder.encodeToken(query).map { Double($0) }
let results = try corpus.index.find(near: queryVector, limit: 8)
//do {
//let queryVector: [Double] = _documentEncoder.encodeToken(query).map { Double($0) }
//let results = try corpus.index.find(near: queryVector, limit: 8)
let results = corpus.searchFor(query)
for result in results {
print(corpus.getUntokenizedDocument(at: result.id))
print(result)
}
} catch {
print("Error when trying corpus.encodedDocuments.find(): \(error)")
}
//} catch {
// print("Error when trying corpus.encodedDocuments.find(): \(error)")
//}
}
func testQueryLargeCorpus() async throws {
......@@ -137,23 +139,19 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
"One, remember to look up at the stars and not down at your feet. Two, never give up work. Work gives you meaning and purpose and life is empty without it. Three, if you are lucky enough to find love, remember it is there and don't throw it away. - Stephen Hawking",
"All science is either physics or stamp collecting. - Ernest Rutherford"
]
let query = "I love Albert Einstein!"
let documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d)
var corpus = HNSWCorpus<Double,ContextFreeEncoder,String>(encoder: documentEncoder)
var corpus = HNSWCorpus<String,ContextFreeEncoder<Double>>()
corpus.addUntokenizedDocuments(docs)
do {
let queryVector: [Double] = documentEncoder.encodeToken(query).map { Double($0) }
let results = try corpus.index.find(near: queryVector, limit: 8)
for result in results {
print(corpus.getUntokenizedDocument(at: result.id))
}
} catch {
print("Error when trying corpus.encodedDocuments.find(): \(error)")
let results = corpus.searchFor(query)
for result in results {
print(result)
}
}
func testQueryGuephSubredditCorpus() async throws {
......@@ -166,8 +164,8 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
let (submissions, _ ): ([Submission],[Data]) = try await loadFromRedditArchive(submissionsData)
let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d)
let corpus = HNSWCorpus<Double,ContextFreeEncoder,String>(encoder: _documentEncoder)
//let _documentEncoder = ContextFreeEncoder<Double>(source: .glove6B50d)
let corpus = HNSWCorpus<String,ContextFreeEncoder<Double>>()
for submission in submissions {
if let text = submission.selftext {
......@@ -177,15 +175,10 @@ final class EphemeralHNSWCorpusTests: XCTestCase {
let query = "Mr. Goose is a very important figure at the University of Waterloo."
do {
let queryVector: [Double] = _documentEncoder.encodeToken(query).map { Double($0) }
let results = try corpus.index.find(near: queryVector, limit: 8)
for result in results {
print(corpus.getUntokenizedDocument(at: result.id))
}
} catch {
print("Error when trying corpus.encodedDocuments.find(): \(error)")
let results = corpus.searchFor(query)
for result in results {
print(result)
}
}
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment