Skip to content
Snippets Groups Projects
Commit 75099c56 authored by Mingchung Xia's avatar Mingchung Xia
Browse files

Implementing adding of untokenized documents to hnsw and sequence boilerplate

parent d5d9c573
No related branches found
No related tags found
1 merge request!13HNSW Implementation with Testcases
Pipeline #109792 failed
...@@ -23,13 +23,27 @@ ...@@ -23,13 +23,27 @@
// MARK: Sequence conformance will be done when HNSWCorpus is complete // MARK: Sequence conformance will be done when HNSWCorpus is complete
/*
extension HNSWCorpus: Sequence { extension HNSWCorpus: Sequence {
typealias Element = [Scalar] typealias Element = [Scalar]
// Sequence Protocol Requirements
@inlinable
func makeIterator() -> AnyIterator<Element> {
/// DeterministicSampleVectorIndex likely does not have .count or .elementAt implemented
/// This provides a boilerplate template for the protocol conformance
var index = 0
return AnyIterator {
guard index < self.encodedDocuments.count else { return nil } // .count
let element = self.encodedDocuments.elementAt(index) // .elementAt
index += 1
return element
}
}
/*
// Sequence Protocol Requirements // Sequence Protocol Requirements
@inlinable @inlinable
func makeIterator() -> Dictionary<Int, [Scalar]>.Values.Iterator { func makeIterator() -> Dictionary<Int, [Scalar]>.Values.Iterator {
...@@ -57,6 +71,6 @@ extension HNSWCorpus: Sequence { ...@@ -57,6 +71,6 @@ extension HNSWCorpus: Sequence {
func index(after i: Dictionary<Int, [Scalar]>.Index) -> Dictionary<Int, [Scalar]>.Index { func index(after i: Dictionary<Int, [Scalar]>.Index) -> Dictionary<Int, [Scalar]>.Index {
return encodedDocuments.index(after: i) return encodedDocuments.index(after: i)
} }
*/
} }
*/
...@@ -28,7 +28,6 @@ import PriorityHeapAlgorithms ...@@ -28,7 +28,6 @@ import PriorityHeapAlgorithms
import SimilarityMetric import SimilarityMetric
import HNSWAlgorithm import HNSWAlgorithm
import HNSWEphemeral import HNSWEphemeral
//import HNSWSample
import GameplayKit import GameplayKit
...@@ -40,7 +39,7 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus { ...@@ -40,7 +39,7 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
// var encodedDocuments: [Int : [Scalar]] = [:] // var encodedDocuments: [Int : [Scalar]] = [:]
// MARK: typicalNeighbourhoodSize is unknown /// typicalNeighbourhoodSize is unknown
var encodedDocuments: DeterministicSampleVectorIndex = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: 20) var encodedDocuments: DeterministicSampleVectorIndex = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: 20)
init(_documentEncoder: ContextFreeEncoder<Scalar>) { init(_documentEncoder: ContextFreeEncoder<Scalar>) {
...@@ -48,37 +47,19 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus { ...@@ -48,37 +47,19 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
zeroes = Array(repeating: Scalar(0), count: 384) zeroes = Array(repeating: Scalar(0), count: 384)
} }
// TODO: Complete implementation of addUntokenizedDocument
@inlinable @inlinable
func addUntokenizedDocument(_ document: String) { func addUntokenizedDocument(_ document: String) {
encodedDocuments.insertRandom(document) let encodedDocument = _documentEncoder.encodeSentence(document)
encodedDocuments.insert(vector: encodedDocument)
} }
// MARK: HNSW indexes do not support deletion - index must be rebuilt
// The following code is taken from Tests/HNSWTests/HNSWIndexTests.swift // HNSW indexes do not support deletion - index must be rebuilt
// The test case randomly inserts and randomly queries neighbours.
//
// var index = DeterministicSampleVectorIndex(typicalNeighborhoodSize: 20)
// for _ in 0..<100 {
// index.insertRandom(range: 0...1)
// }
//
// for i in 0..<10 {
// let sample = index.generateRandom(range: 0...1)
// print("iter \(i): \(sample)")
// let hnswResults = try! index.find(near: sample, limit: 10)
// let exactResult = try! index.find(near: sample, limit: 1, exact: true)
// XCTAssert(exactResult.contains(where: { $0.id == hnswResults[0].id }))
// }
} }
// TODO: Continue overwriting these structures: this implementation uses the Vector instead of [Double]
public struct DeterministicSampleVectorIndex<Vector: Collection & Codable> where Vector.Element: BinaryFloatingPoint { public struct DeterministicSampleVectorIndex<Vector: Collection & Codable> where Vector.Element: BinaryFloatingPoint {
// MARK: Index accepts only [Double]
public typealias Index = EphemeralVectorIndex<Int, Int, CartesianDistanceMetric<[Double]>, Void> public typealias Index = EphemeralVectorIndex<Int, Int, CartesianDistanceMetric<[Double]>, Void>
public var base: Index public var base: Index
...@@ -86,38 +67,49 @@ public struct DeterministicSampleVectorIndex<Vector: Collection & Codable> where ...@@ -86,38 +67,49 @@ public struct DeterministicSampleVectorIndex<Vector: Collection & Codable> where
base = .init(metric: .init(), config: .unstableDefault(typicalNeighborhoodSize: typicalNeighborhoodSize)) base = .init(metric: .init(), config: .unstableDefault(typicalNeighborhoodSize: typicalNeighborhoodSize))
} }
private var vectorRNG = DeterministicRandomNumberGenerator(seed: 0) // private var vectorRNG = DeterministicRandomNumberGenerator(seed: 0)
private var graphRNG = DeterministicRandomNumberGenerator(seed: 1) // private var graphRNG = DeterministicRandomNumberGenerator(seed: 1)
public func find(near query: Vector, limit: Int, exact: Bool = false) throws -> [Index.Neighbor] { public func find(near query: Vector, limit: Int, exact: Bool = false) throws -> [Index.Neighbor] {
if exact { if exact {
// Exact search
Array(PriorityHeap(base.vectors.enumerated().map { Array(PriorityHeap(base.vectors.enumerated().map {
let similarity = base.metric.similarity(between: query as! [Double], $0.element) let similarity = base.metric.similarity(between: query as! [Double], $0.element)
return NearbyVector(id: $0.offset, vector: $0.element, priority: similarity) return NearbyVector(id: $0.offset, vector: $0.element, priority: similarity)
}).descending().prefix(limit)) }).descending().prefix(limit))
} else { } else {
// Approximation search
Array(try base.find(near: query as! [Double], limit: limit)) Array(try base.find(near: query as! [Double], limit: limit))
} }
} }
// Should we be generating random Vector instead of CGPoint? How long is a Vector? public mutating func insert(vector: Vector) {
// Convert the generic vector to a [Double], which is the required type for `base.insert`
public mutating func generateRandom(range: ClosedRange<Double>) -> Vector { let doubleVector = vector.map { Double($0) }
/*
CGPoint( // Insert the vector into the HNSW graph
x: .random(in: range, using: &vectorRNG), base.insert(doubleVector)
y: .random(in: range, using: &vectorRNG)
)
*/
} }
/*
public mutating func insertRandom(range: ClosedRange<Double>) { public mutating func insertRandom(range: ClosedRange<Double>) {
base.insert(generateRandom(range: range) as! [Double], using: &graphRNG) base.insert(generateRandom(range: range) as! [Double], using: &graphRNG)
} }
*/
} }
public struct CartesianDistanceMetric<Vector: Collection & Codable>: SimilarityMetric where Vector.Element: BinaryFloatingPoint{
public func similarity(between someItem: Vector, _ otherItem: Vector) -> Vector.Element {
// Naïve cartesian distance
let squaredSum = zip(someItem, otherItem)
.map { (x, y) in (x - y) * (x - y) }
.reduce(0, +)
return sqrt(squaredSum)
}
}
/*
struct DeterministicRandomNumberGenerator: RandomNumberGenerator { struct DeterministicRandomNumberGenerator: RandomNumberGenerator {
private let randomSource: GKMersenneTwisterRandomSource private let randomSource: GKMersenneTwisterRandomSource
...@@ -131,15 +123,4 @@ struct DeterministicRandomNumberGenerator: RandomNumberGenerator { ...@@ -131,15 +123,4 @@ struct DeterministicRandomNumberGenerator: RandomNumberGenerator {
return upperBits | lowerBits return upperBits | lowerBits
} }
} }
*/
public struct CartesianDistanceMetric<Vector: Collection & Codable>: SimilarityMetric where Vector.Element: BinaryFloatingPoint{
public func similarity(between someItem: Vector, _ otherItem: Vector) -> Vector.Element {
// Naïve cartesian distance
let squaredSum = zip(someItem, otherItem)
.map { (x, y) in (x - y) * (x - y) }
.reduce(0, +)
return sqrt(squaredSum)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment