Skip to content
Snippets Groups Projects
Commit c1f31cfd authored by Mingchung Xia's avatar Mingchung Xia
Browse files

HNSW Sequence and Collection conformance, document encoding and RNG insertions

parent d1cfedea
No related branches found
No related tags found
1 merge request!13HNSW Implementation with Testcases
Pipeline #110117 failed
......@@ -21,56 +21,42 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
// OTHER DEALINGS IN THE SOFTWARE.
// MARK: Sequence conformance will be done when HNSWCorpus is complete
extension HNSWCorpus: Sequence {
extension HNSWCorpus: Sequence, Collection {
typealias Element = [Scalar]
// Sequence Protocol Requirements
@inlinable
@inlinable
func makeIterator() -> AnyIterator<Element> {
/// DeterministicSampleVectorIndex likely does not have .count or .elementAt implemented
/// This provides a boilerplate template for the protocol conformance
var index = 0
return AnyIterator {
guard index < self.encodedDocuments.count else { return nil } // .count
let element = self.encodedDocuments.elementAt(index) // .elementAt
index += 1
defer { index += 1 }
guard index < self.encodedDocuments.base.vectors.count else { return nil }
let element = self.encodedDocuments.base.vectors[index] // consider using .find
return element
}
}
/*
// Sequence Protocol Requirements
@inlinable
func makeIterator() -> Dictionary<Int, [Scalar]>.Values.Iterator {
return encodedDocuments.values.makeIterator()
}
// Collection Protocol Requirements
@inlinable
var startIndex: Dictionary<Int, [Scalar]>.Index {
return encodedDocuments.startIndex
var startIndex: Int {
return encodedDocuments.base.vectors.startIndex
}
@inlinable
var endIndex: Dictionary<Int, [Scalar]>.Index {
return encodedDocuments.endIndex
var endIndex: Int {
return encodedDocuments.base.vectors.endIndex
}
@inlinable
subscript(position: Dictionary<Int, [Scalar]>.Index) -> [Scalar] {
encodedDocuments.values[position]
subscript(position: Int) -> Element {
return encodedDocuments.base.vectors[position]
}
@inlinable
func index(after i: Dictionary<Int, [Scalar]>.Index) -> Dictionary<Int, [Scalar]>.Index {
return encodedDocuments.index(after: i)
func index(after i: Int) -> Int {
return encodedDocuments.base.vectors.index(after: i)
}
*/
}
......@@ -24,7 +24,6 @@
import Foundation
import PriorityHeapModule
import PriorityHeapAlgorithms
import SimilarityMetric
import HNSWAlgorithm
import HNSWEphemeral
......@@ -33,13 +32,11 @@ import GameplayKit
class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
var _documentEncoder: ContextFreeEncoder<Scalar>
private var _documentEncoder: ContextFreeEncoder<Scalar>
var zeroes: [Scalar]
var count: Int { 0 }
// var encodedDocuments: [Int : [Scalar]] = [:]
/// typicalNeighbourhoodSize is unknown
/// typicalNeighbourhoodSize = 20 is a standard benchmark
var encodedDocuments: DeterministicSampleVectorIndex = DeterministicSampleVectorIndex<[Scalar]>(typicalNeighborhoodSize: 20)
init(_documentEncoder: ContextFreeEncoder<Scalar>) {
......@@ -49,67 +46,53 @@ class HNSWCorpus<Scalar: BinaryFloatingPoint & Codable>: SNLPCorpus {
@inlinable
func addUntokenizedDocument(_ document: String) {
let encodedDocument = _documentEncoder.encodeSentence(document)
encodedDocuments.insert(vector: encodedDocument)
encodedDocuments.insert((_documentEncoder.encodeSentence(document) as! [Scalar]))
}
// HNSW indexes do not support deletion - index must be rebuilt
}
public struct DeterministicSampleVectorIndex<Vector: Collection & Codable> where Vector.Element: BinaryFloatingPoint {
// MARK: Index accepts only [Double]
public typealias Index = EphemeralVectorIndex<Int, Int, CartesianDistanceMetric<[Double]>, Void>
/// EmphermalVectorIndex<Key: BinaryInteger, Level: BinaryInteger, Metric: SimilarityMetric, Metadata>
public typealias Index = EphemeralVectorIndex<Int, Int, CartesianDistanceMetric<Vector>, Void>
public var base: Index
public init(typicalNeighborhoodSize: Int) {
base = .init(metric: .init(), config: .unstableDefault(typicalNeighborhoodSize: typicalNeighborhoodSize))
base = .init(metric: CartesianDistanceMetric<Vector>(), config: .unstableDefault(typicalNeighborhoodSize: typicalNeighborhoodSize))
}
// private var vectorRNG = DeterministicRandomNumberGenerator(seed: 0)
// private var graphRNG = DeterministicRandomNumberGenerator(seed: 1)
private var graphRNG = DeterministicRandomNumberGenerator(seed: 1)
public func find(near query: Vector, limit: Int, exact: Bool = false) throws -> [Index.Neighbor] {
if exact {
// Exact search
Array(PriorityHeap(base.vectors.enumerated().map {
let similarity = base.metric.similarity(between: query as! [Double], $0.element)
return Array(PriorityHeap(base.vectors.enumerated().map {
let similarity = base.metric.similarity(between: query, $0.element)
return NearbyVector(id: $0.offset, vector: $0.element, priority: similarity)
}).descending().prefix(limit))
} else {
// Approximation search
Array(try base.find(near: query as! [Double], limit: limit))
return Array(try base.find(near: query, limit: limit))
}
}
public mutating func insert(vector: Vector) {
// Convert the generic vector to a [Double], which is the required type for `base.insert`
let doubleVector = vector.map { Double($0) }
// Insert the vector into the HNSW graph
base.insert(doubleVector)
}
/*
public mutating func insertRandom(range: ClosedRange<Double>) {
base.insert(generateRandom(range: range) as! [Double], using: &graphRNG)
public mutating func insert(_ vector: Vector) {
let convertedVector: [Double] = vector.map{ Double($0) }
if let metricVector = convertedVector as? CartesianDistanceMetric<Vector>.Vector {
base.insert(metricVector, using: &graphRNG)
} else {
fatalError("Unable to get metric vector")
}
}
*/
}
public struct CartesianDistanceMetric<Vector: Collection & Codable>: SimilarityMetric where Vector.Element: BinaryFloatingPoint{
public struct CartesianDistanceMetric<Vector: Collection & Codable>: SimilarityMetric where Vector.Element: BinaryFloatingPoint {
public func similarity(between someItem: Vector, _ otherItem: Vector) -> Vector.Element {
// Naïve cartesian distance
let squaredSum = zip(someItem, otherItem)
.map { (x, y) in (x - y) * (x - y) }
.reduce(0, +)
return sqrt(squaredSum)
}
}
/*
struct DeterministicRandomNumberGenerator: RandomNumberGenerator {
private let randomSource: GKMersenneTwisterRandomSource
......@@ -123,4 +106,3 @@ struct DeterministicRandomNumberGenerator: RandomNumberGenerator {
return upperBits | lowerBits
}
}
*/
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment