Skip to content
Snippets Groups Projects
Commit d5d9c573 authored by Mingchung Xia's avatar Mingchung Xia
Browse files

Started on HNSWCorpus

parent e81dc37a
No related branches found
No related tags found
1 merge request: !13 HNSW Implementation with Testcases
......@@ -21,6 +21,10 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE.
// MARK: Sequence conformance will be done when HNSWCorpus is complete
/*
extension HNSWCorpus: Sequence {
typealias Element = [Scalar]
......@@ -54,3 +58,5 @@ extension HNSWCorpus: Sequence {
return encodedDocuments.index(after: i)
}
}
*/
......@@ -28,6 +28,8 @@ import PriorityHeapAlgorithms
import SimilarityMetric
import HNSWAlgorithm
import HNSWEphemeral
//import HNSWSample
import GameplayKit
class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
......@@ -36,18 +38,27 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
/// Zero-valued vector; the initializer fills it with 384 zero scalars.
var zeroes: [Scalar]
/// Always evaluates to 0 in this version — presumably a placeholder until the index can report its size.
var count: Int { 0 }
// NOTE(review): `encodedDocuments` is declared twice in this view (the dictionary here and the
// vector index below). This looks like the old and new lines of a diff shown together —
// confirm that only one declaration remains in the actual file.
var encodedDocuments: [Int : [Scalar]] = [:] // TODO: This should be replaced by HNSW
// var encodedDocuments: [Int : [Scalar]] = [:]
// MARK: typicalNeighbourhoodSize is unknown
/// Vector index holding the corpus, created with a typical neighborhood size of 20
/// (per the note above, the appropriate value has not been determined yet).
var encodedDocuments: DeterministicSampleVectorIndex = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: 20)
/// Creates a corpus that uses the given context-free encoder.
///
/// - Parameter _documentEncoder: Encoder stored on the corpus for later use.
init(_documentEncoder: ContextFreeEncoder<Scalar>) {
    // Fixed-length zero vector of 384 components
    // (presumably the encoder's embedding width — TODO confirm).
    zeroes = [Scalar](repeating: .zero, count: 384)
    self._documentEncoder = _documentEncoder
}
// TODO: Complete implementation of addUntokenizedDocument
/// Intended to add a raw (untokenized) document string to the corpus.
///
/// - Parameter document: The document text to add.
///
/// NOTE(review): not functional as shown. See the inline notes below.
@inlinable
func addUntokenizedDocument(_ document: String) {
// Traps unconditionally. `fatalError` never returns, so any statement after it cannot run.
fatalError("HNSWCorpus not implemented yet. Get on it.")
// NOTE(review): the only `insertRandom` visible in this file is declared as
// `insertRandom(range: ClosedRange<Double>)`; this call passes an unlabeled `String`
// and does not match that signature — confirm the intended insertion API.
// (If the `fatalError` line above is a removed diff line, this call is the sole statement.)
encodedDocuments.insertRandom(document)
}
// MARK: HNSW indexes do not support deletion - index must be rebuilt
// The following code is taken from Tests/HNSWTests/HNSWIndexTests.swift
// The test case randomly inserts and randomly queries neighbours.
//
// var index = DeterministicSampleVectorIndex(typicalNeighborhoodSize: 20)
// for _ in 0..<100 {
// index.insertRandom(range: 0...1)
......@@ -65,7 +76,7 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
// TODO: Continue overwriting these structures: this implementation uses the Vector instead of [Double]
public struct DeterministicSampleVectorIndex<Vector: Collection & Codable> where Vector.Element: BinaryFloatingPoint {
public typealias Index = EphemeralVectorIndex<Int, Int, CartesianDistanceMetric<[Double]>, Void>
......@@ -75,6 +86,9 @@ public struct DeterministicSampleVectorIndex<Vector: Collection & Codable> where
base = .init(metric: .init(), config: .unstableDefault(typicalNeighborhoodSize: typicalNeighborhoodSize))
}
// Two independently seeded generators (seed 0 and seed 1).
// `graphRNG` is passed to `base.insert(_:using:)` by `insertRandom(range:)`.
// `vectorRNG` is presumably meant for drawing random vector components in
// `generateRandom(range:)` — no live code in this view uses it yet.
private var vectorRNG = DeterministicRandomNumberGenerator(seed: 0)
private var graphRNG = DeterministicRandomNumberGenerator(seed: 1)
public func find(near query: Vector, limit: Int, exact: Bool = false) throws -> [Index.Neighbor] {
if exact {
Array(PriorityHeap(base.vectors.enumerated().map {
......@@ -86,8 +100,39 @@ public struct DeterministicSampleVectorIndex<Vector: Collection & Codable> where
}
}
// Should we be generating random Vector instead of CGPoint? How long is a Vector?
/// Intended to produce a random vector whose components are drawn from `range`.
///
/// - Parameter range: Closed interval each component is meant to be sampled from.
/// - Returns: Nothing yet — the function traps (see below).
///
/// NOTE(review): not implemented. `Collection` offers no initializer, so a `Vector`
/// cannot be constructed generically here, and the vector length is still undecided.
/// A body without a `return` does not compile for a non-`Void` result, so the
/// unimplemented state is made explicit with a trap instead.
public mutating func generateRandom(range: ClosedRange<Double>) -> Vector {
    // Earlier two-dimensional sketch, kept for reference:
    /*
    CGPoint(
        x: .random(in: range, using: &vectorRNG),
        y: .random(in: range, using: &vectorRNG)
    )
    */
    fatalError("generateRandom(range:) is not implemented: Vector length and construction are undecided.")
}
/// Inserts one randomly generated vector into the underlying index.
///
/// - Parameter range: Closed interval forwarded to `generateRandom(range:)`.
public mutating func insertRandom(range: ClosedRange<Double>) {
    // The index is fed `[Double]`. Convert element-wise instead of force-casting with `as!`,
    // so any `Vector` with `BinaryFloatingPoint` elements is accepted rather than crashing
    // whenever `Vector` is not literally `[Double]`.
    let components = generateRandom(range: range).map { Double($0) }
    base.insert(components, using: &graphRNG)
}
}
/// A `RandomNumberGenerator` that draws its bits from a seeded
/// `GKMersenneTwisterRandomSource`.
struct DeterministicRandomNumberGenerator: RandomNumberGenerator {
    private let randomSource: GKMersenneTwisterRandomSource

    /// - Parameter seed: Seed handed to the underlying Mersenne Twister source.
    init(seed: UInt64) {
        randomSource = GKMersenneTwisterRandomSource(seed: seed)
    }

    /// Builds a 64-bit value from two consecutive draws: the first fills the
    /// high 32 bits, the second the low 32 bits.
    mutating func next() -> UInt64 {
        let high = nextHalfWord()
        let low = nextHalfWord()
        return (high << 32) | low
    }

    /// Takes one draw from the source and reinterprets its 32-bit pattern as unsigned.
    private func nextHalfWord() -> UInt64 {
        let draw = Int32(randomSource.nextInt())
        return UInt64(UInt32(bitPattern: draw))
    }
}
public struct CartesianDistanceMetric<Vector: Collection & Codable>: SimilarityMetric where Vector.Element: BinaryFloatingPoint{
public func similarity(between someItem: Vector, _ otherItem: Vector) -> Vector.Element {
// Naïve cartesian distance
......@@ -98,4 +143,3 @@ public struct CartesianDistanceMetric<Vector: Collection & Codable>: SimilarityM
return sqrt(squaredSum)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment